Add ability to prevent exception propagation if building as part of ORT when ORT has exceptions disabled (#368)

* Add ability to prevent exception propagation with top level try/catch handler macros.

If a combined build with ORT has exceptions disabled in ORT but ort-ext has an operator that requires exceptions, we enable exceptions in ort-ext but prevent them from propagating up via try/catch in the entry points that ORT can call
  - RegisterCustomOps
  - CustomOpBase constructor and Compute

Removed some places in CustomOpApi that threw if OpKernelInfo* was nullptr by standardizing all kernels to store the OpKernelInfo provided in the ctor.

Added unit tests
  - need to validate on more platforms and add CI for build where we don't want to allow exceptions to propagate

* Update pyop

* Update CMakeLists.txt

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Update includes/exceptions.h

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Update includes/exceptions.h

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Update includes/onnxruntime_customop.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Merge with main and update
Address PR comments
Fix some issues.

* Delete local file

* Fix pyop update

* Add CI
Address PR comments

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Scott McKay 2023-02-28 04:31:44 +10:00 коммит произвёл GitHub
Родитель b375cb57e6
Коммит 5e44a7c3c9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
77 изменённых файлов: 669 добавлений и 510 удалений

Просмотреть файл

@ -449,3 +449,32 @@ jobs:
parameters:
xcWorkspacePath: '$(Build.SourcesDirectory)/test/ios/OrtExtensionsUsage/OrtExtensionsUsage.xcworkspace'
scheme: 'OrtExtensionsUsage'
#####################################
# Linux prevent exception propagation
#####################################
- job: Linux_Prevent_Exception_Propagation
pool:
vmImage: 'ubuntu-latest'
steps:
# Simulate an embedded build as part of ORT with exceptions disabled by manually setting CMAKE_CXX_FLAGS and
# using _OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE. The build should re-enable exceptions within ort-ext
# but prevent them from propagating. Unit tests are run to validate this.
- script: '
./build_lib.sh --enable_cxx_tests --onnxruntime_version 1.14.0 --config RelWithDebInfo
--cmake_extra_defines
_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE=ON OCOS_ENABLE_CPP_EXCEPTIONS=OFF
CMAKE_CXX_FLAGS="-fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables"
'
displayName: Build ort-ext with exception propagation disabled
# As an extra validation check CMakeCache.txt as well
- script: |
grep "^_OCOS_PREVENT_EXCEPTION_PROPAGATION.*ON$" build/Linux/RelWithDebInfo/CMakeCache.txt
if [ $? -ne 0 ]; then
echo "Exception propagation was not enabled correctly."
exit 1
fi

Просмотреть файл

@ -150,6 +150,12 @@ endif()
# External dependencies
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake)
# PROJECT_IS_TOP_LEVEL is available since 3.21
get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
if(not_top AND ONNXRUNTIME_ROOT)
set(_ONNXRUNTIME_EMBEDDED TRUE)
endif()
if(OCOS_ENABLE_SELECTED_OPLIST)
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
@ -158,6 +164,48 @@ if(OCOS_ENABLE_SELECTED_OPLIST)
include(_selectedoplist)
endif()
set(_OCOS_EXCEPTIONS_REQUIRED OFF)
if (OCOS_ENABLE_GPT2_TOKENIZER OR
OCOS_ENABLE_WORDPIECE_TOKENIZER OR
OCOS_ENABLE_BLINGFIRE OR
OCOS_ENABLE_SPM_TOKENIZER OR
(OCOS_ENABLE_CV2 OR OCOS_ENABLE_OPENCV_CODECS OR OCOS_ENABLE_VISION))
set(_OCOS_EXCEPTIONS_REQUIRED ON)
endif()
# Special case an embedded build with ORT exceptions disabled but custom ops that require exceptions.
# Allow using an override so we can do a direct build of ort-ext in a CI without having to embed it in an ORT build.
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION OFF)
if (_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE)
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION ${_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE})
elseif(_ONNXRUNTIME_EMBEDDED AND onnxruntime_DISABLE_EXCEPTIONS AND _OCOS_EXCEPTIONS_REQUIRED)
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION ON)
endif()
if (_OCOS_PREVENT_EXCEPTION_PROPAGATION)
message(STATUS "Embedded build as part of ONNX Runtime with exceptions disabled. "
"Extensions will be built with exceptions enabled due to included custom ops "
"using 3rd party libraries that require exceptions.")
if (NOT OCOS_ENABLE_CPP_EXCEPTIONS)
message(WARNING "Enabling C++ exception support as custom ops included in the build require them to be enabled.")
set(OCOS_ENABLE_CPP_EXCEPTIONS ON)
endif()
# undo the flags that ORT has set to disable exceptions.
# see https://github.com/microsoft/onnxruntime/blob/b1abb8c656c597bf221bd85682ae3d9e350d9aba/cmake/adjust_global_compile_flags.cmake#L160-L169
if(MSVC)
string(REPLACE "/EHs-c-" "/EHsc" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
else()
string(REPLACE "-fno-exceptions" "-fexceptions" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE "-fno-unwind-tables" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE "-fno-asynchronous-unwind-tables" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
endif()
# the ort-ext code has to provide a barrier between the exception enabled custom op code and ORT.
add_compile_definitions(OCOS_PREVENT_EXCEPTION_PROPAGATION)
endif()
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS)
endif()
@ -182,14 +230,7 @@ endfunction()
# set default MSVC warning level to 3 for external dependencies
set_msvc_c_cpp_compiler_warning_level(3)
# PROJECT_IS_TOP_LEVEL is available since 3.21
get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
if(not_top AND ONNXRUNTIME_ROOT)
set(_ONNXRUNTIME_EMBEDDED TRUE)
endif()
include(ext_ortlib)
include(gsl)
macro(standardize_output_folder bin_target)
@ -214,7 +255,7 @@ endif()
# ### scan all source files
set(TARGET_SRC_NOEXCEPTION)
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "includes/*.h*")
if(OCOS_ENABLE_TF_STRING)
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
@ -553,7 +594,7 @@ endforeach()
if(OCOS_ENABLE_CTEST)
if (OCOS_ENABLE_SELECTED_OPLIST)
# currently the tests don't handle operator exclusion cleanly.
message(FATAL "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
message(FATAL_ERROR "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
endif()
# Enable CTest
@ -593,7 +634,8 @@ if(OCOS_ENABLE_CTEST)
target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
endif()
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS})
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main gmock_main
${ocos_libraries} ${LINUX_CC_FLAGS})
# Copy ONNXRuntime DLLs into bin folder for testing on Windows platform
if(WIN32)

0
build_lib.sh Normal file → Executable file
Просмотреть файл

99
includes/exceptions.h Normal file
Просмотреть файл

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
#if defined(__ANDROID__)
#include <android/log.h>
#else
#include <iostream>
#endif
#endif
#include <stdexcept>
#include "onnxruntime_c_api.h"
namespace OrtW {
// All C++ methods that can fail will throw an exception of this type
struct Exception : std::exception {
Exception(std::string message, OrtErrorCode code) : message_{std::move(message)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
const char* what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
OrtErrorCode code_;
};
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
inline void PrintFinalMessage(const char* file, int line, const char* msg) {
#if defined(__ANDROID__)
__android_log_print(ANDROID_LOG_ERROR, "onnxruntime-extensions", "Exception in %s line %d: %s", file, line, msg);
#else
std::cerr << "Exception in " << file << " line " << line << ": " << msg << std::endl;
#endif
}
#endif
#ifdef OCOS_NO_EXCEPTIONS
#define ORTX_CXX_API_THROW(string, code) \
do { \
OrtW::PrintFinalMessage(__FILE__, __LINE__, OrtW::Exception(string, code).what()); \
abort(); \
} while (false)
#define OCOS_TRY if (true)
#define OCOS_CATCH(x) else if (false)
#define OCOS_RETHROW
// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
// a lambda function. otherwise the exception referred will be undefined and cause build break
#define OCOS_HANDLE_EXCEPTION(func)
#else
#define ORTX_CXX_API_THROW(string, code) \
throw OrtW::Exception(string, code)
#define OCOS_TRY try
#define OCOS_CATCH(x) catch (x)
#define OCOS_RETHROW throw;
#define OCOS_HANDLE_EXCEPTION(func) func()
#endif
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
if (status) {
std::string error_message = ort.GetErrorMessage(status);
OrtErrorCode error_code = ort.GetErrorCode(status);
ort.ReleaseStatus(status);
ORTX_CXX_API_THROW(std::move(error_message), error_code);
}
}
} // namespace OrtW
// macros to wrap entry points that ORT calls where we may need to prevent exceptions propagating upwards to ORT
#define OCOS_API_IMPL_BEGIN \
OCOS_TRY {
// if exceptions are disabled (a 3rd party library could throw so we need to handle that)
// or we're preventing exception propagation, log and abort().
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
#define OCOS_API_IMPL_END \
} \
OCOS_CATCH(const std::exception& ex) { \
OCOS_HANDLE_EXCEPTION([&]() { \
OrtW::PrintFinalMessage(__FILE__, __LINE__, ex.what()); \
abort(); \
}); \
}
#else
// rethrow.
#define OCOS_API_IMPL_END \
} \
OCOS_CATCH(const std::exception&) { \
OCOS_HANDLE_EXCEPTION([&]() { \
OCOS_RETHROW; \
}); \
}
#endif

Просмотреть файл

@ -18,28 +18,29 @@ constexpr const char* c_OpDomain = "ai.onnx.contrib";
constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";
struct BaseKernel {
BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
}
bool HasAttribute(const char* name) const;
bool HasAttribute(const char* name) const noexcept;
template <class T>
bool TryToGetAttribute(const char* name, T& value);
bool TryToGetAttribute(const char* name, T& value) const noexcept;
template <class T>
T TryToGetAttributeWithDefault(const char* name, T default_value) {
T TryToGetAttributeWithDefault(const char* name, T default_value) const noexcept {
T& result = default_value;
TryToGetAttribute(name, result);
return result;
}
void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);
void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
const std::vector<int64_t>& data);
protected:
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
const OrtApi& api_;
OrtW::CustomOpApi ort_;
const OrtKernelInfo* info_;
const OrtKernelInfo& info_;
};
struct OrtTensorDimensions : std::vector<int64_t> {

Просмотреть файл

@ -9,56 +9,17 @@
#include <cstddef>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
#ifdef ORT_NO_EXCEPTIONS
#include <cstdio>
#endif
#include "onnxruntime_c_api.h"
#include "exceptions.h"
namespace OrtW {
// All C++ methods that can fail will throw an exception of this type
struct Exception : std::exception {
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
const char* what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
OrtErrorCode code_;
};
#ifdef ORT_NO_EXCEPTIONS
#define ORTX_CXX_API_THROW(string, code) \
do { \
fprintf(stderr, "%s\n", \
OrtW::Exception(string, code).what()); \
abort(); \
} while (false)
#else
#define ORTX_CXX_API_THROW(string, code) \
throw OrtW::Exception(string, code)
#endif
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
if (status) {
std::string error_message = ort.GetErrorMessage(status);
OrtErrorCode error_code = ort.GetErrorCode(status);
ort.ReleaseStatus(status);
ORTX_CXX_API_THROW(std::move(error_message), error_code);
}
}
//
// Custom OPs (only needed to implement custom OPs)
//
@ -72,7 +33,8 @@ struct CustomOpApi {
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const;
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
size_t dim_values_length) const;
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
template <typename T>
@ -85,7 +47,8 @@ struct CustomOpApi {
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const;
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
size_t dim_count) const;
void ThrowOnError(OrtStatus* status) const {
OrtW::ThrowOnError(api_, status);
@ -99,18 +62,44 @@ template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = 10; // The minimum ORT version supported
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
void* result = nullptr;
OCOS_API_IMPL_BEGIN
result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
OCOS_API_IMPL_END
return result;
};
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
return static_cast<const TOp*>(this_)->GetName();
};
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
return static_cast<const TOp*>(this_)->GetExecutionProviderType();
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
return static_cast<const TOp*>(this_)->GetInputTypeCount();
};
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
return static_cast<const TOp*>(this_)->GetInputType(index);
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
return static_cast<const TOp*>(this_)->GetOutputTypeCount();
};
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
return static_cast<const TOp*>(this_)->GetOutputType(index);
};
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
OCOS_API_IMPL_BEGIN
static_cast<TKernel*>(op_kernel)->Compute(context);
OCOS_API_IMPL_END
};
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
@ -119,26 +108,29 @@ struct CustomOpBase : OrtCustomOp {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
};
}
template <typename... Args>
TKernel* CreateKernelImpl(Args&&... args) const {
// default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
// OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
// calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
return new TKernel(std::forward<Args>(args)...);
return new TKernel(api, info);
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api);
}
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
const char* GetExecutionProviderType() const { return nullptr; }

Просмотреть файл

@ -13,7 +13,7 @@
#include <cstdint>
struct KernelImageDecoder : BaseKernel {
KernelImageDecoder(const OrtApi& api) : BaseKernel(api) {}
KernelImageDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
void Compute(OrtKernelContext* context) {
// Setup inputs
@ -33,7 +33,7 @@ struct KernelImageDecoder : BaseKernel {
// Decode the image
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
// Setup output & copy to destination
@ -41,18 +41,14 @@ struct KernelImageDecoder : BaseKernel {
const int64_t colors = 3;
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
OrtValue *const output_value = ort_.KernelContext_GetOutput(
context, 0, output_dimensions.data(), output_dimensions.size());
OrtValue* const output_value = ort_.KernelContext_GetOutput(
context, 0, output_dimensions.data(), output_dimensions.size());
uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
}
};
struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelImageDecoder(api);
}
const char* GetName() const {
return "ImageDecoder";
}

Просмотреть файл

@ -3,9 +3,8 @@
#include "string_tensor.h"
struct KernelImageReader : BaseKernel {
KernelImageReader(const OrtApi& api) : BaseKernel(api) {
KernelImageReader(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
@ -45,7 +44,7 @@ struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImage
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
}
const char* GetName() const{
const char* GetName() const {
return "ImageReader";
}
};

Просмотреть файл

@ -1,9 +1,8 @@
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
struct KernelGaussianBlur : BaseKernel {
KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
KernelGaussianBlur(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
@ -47,7 +46,7 @@ struct KernelGaussianBlur : BaseKernel {
sigma[0], sigma[1], cv::BORDER_DEFAULT);
OrtValue* image_y = ort_.KernelContext_GetOutput(context,
0, input_data_dimensions.data(), input_data_dimensions.size());
0, input_data_dimensions.data(), input_data_dimensions.size());
float* p_output_image = ort_.GetTensorMutableData<float>(image_y);
memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
}
@ -76,7 +75,7 @@ struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGau
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
const char* GetName() const{
const char* GetName() const {
return "GaussianBlur";
}
};

Просмотреть файл

@ -6,9 +6,8 @@
#include <dlib/matrix.h>
#include "ocos.h"
struct KernelInverse : BaseKernel {
KernelInverse(const OrtApi& api) : BaseKernel(api) {
KernelInverse(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
@ -23,7 +22,7 @@ struct KernelInverse : BaseKernel {
}
OrtValue* output0 = ort_.KernelContext_GetOutput(
context, 0, dimensions.data(), dimensions.size());
context, 0, dimensions.data(), dimensions.size());
float* out0 = ort_.GetTensorMutableData<float>(output0);
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);

Просмотреть файл

@ -6,10 +6,10 @@
#include "ocos.h"
struct KernelNegPos : BaseKernel {
KernelNegPos(const OrtApi& api) : BaseKernel(api) {
KernelNegPos(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context){
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const float* X = ort_.GetTensorData<float>(input_X);
@ -40,22 +40,22 @@ struct KernelNegPos : BaseKernel {
};
struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
const char* GetName() const{
const char* GetName() const {
return "NegPos";
}
size_t GetInputTypeCount() const{
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
size_t GetOutputTypeCount() const{
size_t GetOutputTypeCount() const {
return 2;
}
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

Просмотреть файл

@ -3,7 +3,8 @@
#include "segment_extraction.hpp"
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
@ -11,7 +12,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
OrtTensorDimensions input_dim(ort_, input);
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
std::vector<std::int64_t> segment_value;
@ -35,8 +36,8 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
SetOutput(context, 0, segment_position_dim, segment_position);
SetOutput(context, 1, segment_value_dim, segment_value);
SetOutput(context, 0, segment_position_dim, segment_position);
SetOutput(context, 1, segment_value_dim, segment_value);
}
size_t CustomOpSegmentExtraction::GetInputTypeCount() const {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelSegmentExtraction : BaseKernel {
KernelSegmentExtraction(const OrtApi& api);
KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -20,8 +20,9 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
if (dim_data[0] != dim_seg[0])
ORTX_CXX_API_THROW(MakeString(
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
" segment_ids shape: ", dim_seg),
ORT_INVALID_GRAPH);
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
OrtTensorDimensions dim_out = dim_data;
@ -43,8 +44,9 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
for (; begin != end; ++p_seg) {
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
*(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
*(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), "."),
ORT_RUNTIME_EXCEPTION);
p_out = p_output + *p_seg * in_stride;
p_out_end = p_out + in_stride;
for (; p_out != p_out_end; ++p_out, ++begin)
@ -52,7 +54,7 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
}
}
KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
KernelSegmentSum::KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelSegmentSum::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelSegmentSum : BaseKernel {
KernelSegmentSum(const OrtApi& api);
KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -3,14 +3,11 @@
#include <sstream>
#include "ocos.h"
bool BaseKernel::HasAttribute(const char* name) const {
if (info_ == nullptr) {
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
bool BaseKernel::HasAttribute(const char* name) const noexcept {
size_t size;
std::string out;
// Crashes here.
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
auto r = api_.GetErrorCode(status);
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
if (has) {
@ -26,7 +23,7 @@ bool BaseKernel::HasAttribute(const char* name) const {
return true;
}
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
if (status == nullptr) {
return ORT_OK;
}
@ -35,22 +32,19 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
return error_code;
}
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
for (size_t i = 0; i < data.size(); i++) {
data_ptr[i] = data[i];
}
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
const std::vector<int64_t>& data) {
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
for (size_t i = 0; i < data.size(); i++) {
data_ptr[i] = data[i];
}
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
if (info_ == nullptr) {
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) const noexcept {
size_t size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
// The status should be a nullptr when querying for the size.
if (status != nullptr) {
@ -59,7 +53,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
}
value.resize(size);
status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
status = api_.KernelInfoGetAttribute_string(&info_, name, &value[0], &size);
if (GetErrorCodeAndRelease(status) != ORT_OK) {
return false;
}
@ -69,31 +63,19 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
if (info_ == nullptr) {
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) const noexcept {
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &value)) == ORT_OK;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
if (info_ == nullptr) {
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept {
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
if (info_ == nullptr) {
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
int64_t origin_value = 0;
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
return false;
}

Просмотреть файл

@ -8,8 +8,7 @@
#include <codecvt>
#include <algorithm>
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelMaskedFill::Compute(OrtKernelContext* context) {
@ -29,7 +28,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
}
std::vector<std::string> value;
const bool * mask = nullptr;
const bool* mask = nullptr;
GetTensorMutableDataString(api_, ort_, context, input_value, value);
mask = ort_.GetTensorData<bool>(input_mask);
@ -51,10 +50,6 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; };
size_t CustomOpMaskedFill::GetInputTypeCount() const {
@ -69,7 +64,8 @@ ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}};
}
};
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
return 1;

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map>
struct KernelMaskedFill : BaseKernel {
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -5,7 +5,7 @@
#include "op_equal_impl.hpp"
#include <string>
KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringEqual::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringEqual : BaseKernel {
KernelStringEqual(const OrtApi& api);
KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -12,7 +12,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
if (d_length.size() != 1)
ORTX_CXX_API_THROW(MakeString(
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
"First input must have one dimension not ", d_length.size(), "."),
ORT_INVALID_ARGUMENT);
int64_t n_els = d_length[0] - 1;
int64_t n_values = p_n_elements[n_els];
std::vector<int64_t> shape{n_values, 2};
@ -58,7 +59,8 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
@ -77,8 +79,9 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
return max_col;
}
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
: CommonRaggedTensorToDense(api, info) {
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(&info, "missing_value") : -1;
}
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
@ -104,8 +107,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
pos_end = pos + max_col;
if (pos_end > shape_out_size)
ORTX_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."),
ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
dense[pos] = p_values[j];
}
@ -127,10 +131,6 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpRaggedTensorToDense::GetName() const {
return "RaggedTensorToDense";
};
@ -139,7 +139,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensorToDense(api, info) {
}
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
@ -162,8 +162,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
pos_end = pos + max_col;
if (pos_end > shape_out_size)
ORTX_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."),
ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
dense[static_cast<size_t>(pos)] = input[static_cast<size_t>(j)];
}
@ -186,10 +187,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringRaggedTensorToDense::GetName() const {
return "StringRaggedTensorToDense";
};

Просмотреть файл

@ -6,8 +6,8 @@
#include "ocos.h"
struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(const OrtApi& api)
: BaseKernel(api) {}
KernelRaggedTensorToSparse(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {}
void Compute(OrtKernelContext* context);
};
@ -21,7 +21,7 @@ struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToS
};
struct CommonRaggedTensorToDense : BaseKernel {
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
protected:
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
@ -29,7 +29,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
};
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -41,20 +41,19 @@ struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDe
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> {
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense,
KernelStringRaggedTensorToDense> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};

Просмотреть файл

@ -8,8 +8,9 @@
#include "re2/re2.h"
#include "string_tensor.h"
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
}
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
@ -28,12 +29,14 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
"pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."),
ORT_INVALID_ARGUMENT);
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
"rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."),
ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
@ -62,10 +65,6 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output);
}
void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
size_t CustomOpStringRegexReplace::GetInputTypeCount() const {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringRegexReplace : BaseKernel {
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
protected:
@ -15,7 +15,6 @@ struct KernelStringRegexReplace : BaseKernel {
};
struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,8 @@
#include <vector>
#include <cmath>
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
@ -24,13 +25,13 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
str_pattern.size(), " values."),
ORT_INVALID_ARGUMENT);
if (str_keep_pattern.size() > 1)
ORTX_CXX_API_THROW(MakeString(
"Third input must contain only one element. It has ",
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
str_keep_pattern.size(), " values."),
ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
@ -50,8 +51,8 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
std::vector<int64_t> begin_offsets;
std::vector<int64_t> end_offsets;
RegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
include_delimiter, keep_reg,
tokens, begin_offsets, end_offsets);
include_delimiter, keep_reg,
tokens, begin_offsets, end_offsets);
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
for (size_t j = 0; j < begin_offsets.size(); ++j) {
all_begin_offsets.push_back(begin_offsets[j]);
@ -79,10 +80,6 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; };
size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const {
@ -106,7 +103,7 @@ ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(siz
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -8,12 +8,11 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringRegexSplitWithOffsets : BaseKernel {
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -8,8 +8,7 @@
#include <codecvt>
#include <algorithm>
KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
KernelStringConcat::KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringConcat::Compute(OrtKernelContext* context) {
@ -53,4 +52,4 @@ size_t CustomOpStringConcat::GetOutputTypeCount() const {
ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringConcat : BaseKernel {
KernelStringConcat(const OrtApi& api);
KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -7,10 +7,10 @@
#include <regex>
#include "string_tensor.h"
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}
void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
@ -24,19 +24,18 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
// Verifications
OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (rewrite_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW(MakeString("rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (str_pattern[0].empty()) {
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
@ -70,10 +69,6 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output);
}
void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringECMARegexReplace::GetName() const { return "StringECMARegexReplace"; };
size_t CustomOpStringECMARegexReplace::GetInputTypeCount() const {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringECMARegexReplace : BaseKernel {
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
protected:
@ -16,7 +16,6 @@ struct KernelStringECMARegexReplace : BaseKernel {
};
struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,8 +9,9 @@
#include "string_ecmaregex_split.hpp"
#include "string_tensor.h"
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api,
const OrtKernelInfo& info)
: BaseKernel(api, info) {
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}
@ -28,9 +29,13 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_keep_pattern.size() > 1)
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_pattern[0].empty())
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
@ -84,10 +89,6 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringECMARegexSplitWithOffsets::GetName() const { return "StringECMARegexSplitWithOffsets"; };
size_t CustomOpStringECMARegexSplitWithOffsets::GetInputTypeCount() const {
@ -112,7 +113,7 @@ ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -9,14 +9,14 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
bool ignore_case_;
};
struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
@ -48,7 +48,7 @@ void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
end_offsets.push_back(prev_pos + matched_length);
}
//no mather include the delimiter, we should skip it
// no mather include the delimiter, we should skip it
prev_pos += matched_length;
}

Просмотреть файл

@ -9,7 +9,7 @@
#include "string_hash.hpp"
KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
KernelStringHash::KernelStringHash(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringHash::Compute(OrtKernelContext* context) {
@ -68,7 +68,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
KernelStringHashFast::KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringHashFast::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringHash : BaseKernel {
KernelStringHash(const OrtApi& api);
KernelStringHash(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
@ -20,7 +20,7 @@ struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringH
};
struct KernelStringHashFast : BaseKernel {
KernelStringHashFast(const OrtApi& api);
KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_join.hpp"
#include "string_tensor.h"
KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
KernelStringJoin::KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringJoin::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringJoin : BaseKernel {
KernelStringJoin(const OrtApi& api);
KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
#include "ustring.h"
KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
KernelStringLength::KernelStringLength(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringLength::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringLength : BaseKernel {
KernelStringLength(const OrtApi& api);
KernelStringLength(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath>
#include <algorithm>
KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
KernelStringLower::KernelStringLower(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringLower::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringLower : BaseKernel {
KernelStringLower(const OrtApi& api);
KernelStringLower(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -8,11 +8,10 @@
#include <codecvt>
#include <algorithm>
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
auto lines = SplitString(map, "\n", true);
for (const auto& line: lines) {
for (const auto& line : lines) {
auto items = SplitString(line, "\t", true);
if (items.size() != 2) {
@ -30,7 +29,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
OrtTensorDimensions dimensions(ort_, input);
for (auto& str: input_data) {
for (auto& str : input_data) {
if (map_.find(str) != map_.end()) {
str = map_[str];
}
@ -41,10 +40,6 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, input_data, output);
}
void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringMapping::GetName() const { return "StringMapping"; };
size_t CustomOpStringMapping::GetInputTypeCount() const {

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map>
struct KernelStringMapping : BaseKernel {
KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_split.hpp"
#include "string_tensor.h"
KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
KernelStringSplit::KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringSplit::Compute(OrtKernelContext* context) {

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringSplit : BaseKernel {
KernelStringSplit(const OrtApi& api);
KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -4,7 +4,6 @@
#include "string_to_vector.hpp"
#include "string_tensor.h"
StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
ParseMappingTable(map);
ParseUnkownValue(unk);
@ -12,7 +11,7 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
std::vector<std::vector<int64_t>> result;
// Set output dimension
output_dim = input_dim;
output_dim.push_back(vector_len_);
@ -100,10 +99,10 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
}
}
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unk");
impl_ = std::make_shared<StringToVectorImpl>(map, unk);
}
@ -132,10 +131,6 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
}
}
void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpStringToVector::GetName() const { return "StringToVector"; };
size_t CustomOpStringToVector::GetInputTypeCount() const {

Просмотреть файл

@ -9,7 +9,6 @@
#include "ocos.h"
#include "string_utils.h"
class StringToVectorImpl {
public:
StringToVectorImpl(std::string& map, std::string& unk);
@ -29,7 +28,7 @@ class StringToVectorImpl {
};
struct KernelStringToVector : BaseKernel {
KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -37,7 +36,6 @@ struct KernelStringToVector : BaseKernel {
};
struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath>
#include <algorithm>
KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
KernelStringUpper::KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringUpper::Compute(OrtKernelContext* context) {
@ -17,7 +17,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, input_X, X);
for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c){ return static_cast<char>(::toupper(c)); });
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(::toupper(c)); });
}
// Fills the output

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringUpper : BaseKernel {
KernelStringUpper(const OrtApi& api);
KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

Просмотреть файл

@ -4,12 +4,11 @@
#include "vector_to_string.hpp"
#include "string_tensor.h"
namespace std {
template <class T>
size_t hash<std::vector<T>>::operator()(const vector<T>& __vector) const noexcept {
return util::Hash(reinterpret_cast<const char *>(__vector.data()), __vector.size() * sizeof(T));
return util::Hash(reinterpret_cast<const char*>(__vector.data()), __vector.size() * sizeof(T));
}
template struct hash<std::vector<std::string>>;
@ -38,7 +37,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
std::vector<int64_t> key(vector_len_);
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
//construct key
// construct key
for (size_t j = 0; j < vector_len_; j++) {
key[j] = ptr[j];
}
@ -103,9 +102,9 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
}
}
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unk");
// TODO: support more type when we can get input type from OrtKernelInfo
impl_ = std::make_shared<VectorToStringImpl>(map, unk);
@ -124,10 +123,6 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, mapping_result, output);
}
void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpVectorToString::GetName() const { return "VectorToString"; };
size_t CustomOpVectorToString::GetInputTypeCount() const {

Просмотреть файл

@ -30,10 +30,10 @@ class VectorToStringImpl {
std::unordered_map<std::vector<int64_t>, std::string> map_;
std::string unk_value_;
size_t vector_len_;
};
};
struct KernelVectorToString : BaseKernel {
KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -41,7 +41,6 @@ struct KernelVectorToString : BaseKernel {
};
struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,13 +9,13 @@
#include <codecvt>
#include <algorithm>
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
do_lower_case_(do_lower_case),
strip_accents_(strip_accents),
tokenize_chinese_chars_(tokenize_chinese_chars),
tokenize_punctuation_(tokenize_punctuation),
remove_control_chars_(remove_control_chars)
{}
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents,
bool tokenize_punctuation, bool remove_control_chars)
: do_lower_case_(do_lower_case),
strip_accents_(strip_accents),
tokenize_chinese_chars_(tokenize_chinese_chars),
tokenize_punctuation_(tokenize_punctuation),
remove_control_chars_(remove_control_chars) {}
std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
std::vector<ustring> result;
@ -42,7 +42,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
if (do_lower_case_) {
for (auto& c : text) {
c = ToLower(c);
c = ToLower(c);
}
}
@ -57,7 +57,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
continue;
}
// 0x2019 unicode is not punctuation in some Linux platform,
// 0x2019 unicode is not punctuation in some Linux platform,
// to be consistent, take it as punctuation.
if (tokenize_punctuation_ && IsPunct(c)) {
push_current_token_and_clear();
@ -82,7 +82,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
return result;
}
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
@ -109,10 +109,6 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
size_t CustomOpBasicTokenizer::GetInputTypeCount() const {

Просмотреть файл

@ -9,7 +9,8 @@
class BasicTokenizer {
public:
BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars);
BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation,
bool remove_control_chars);
std::vector<ustring> Tokenize(ustring text);
private:
@ -21,14 +22,14 @@ class BasicTokenizer {
};
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BasicTokenizer> tokenizer_;
};
struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -43,13 +43,10 @@ WordpieceTokenizer::WordpieceTokenizer(
std::shared_ptr<BertTokenizerVocab> vocab,
ustring unk_token,
ustring suffix_indicator,
int max_input_chars_per_word
) :
max_input_chars_per_word_(max_input_chars_per_word),
suffix_indicator_(std::move(suffix_indicator)),
unk_token_(std::move(unk_token)),
vocab_(std::move(vocab))
{
int max_input_chars_per_word) : max_input_chars_per_word_(max_input_chars_per_word),
suffix_indicator_(std::move(suffix_indicator)),
unk_token_(std::move(unk_token)),
vocab_(std::move(vocab)) {
unk_token_id_ = vocab_->FindTokenId(unk_token_);
}
@ -190,21 +187,17 @@ BertTokenizer::BertTokenizer(
bool strip_accents,
ustring suffix_indicator,
int32_t max_len,
const std::string& truncation_strategy
) :
max_length_(max_len),
do_basic_tokenize_(do_basic_tokenize),
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy))
{
const std::string& truncation_strategy) : max_length_(max_len),
do_basic_tokenize_(do_basic_tokenize),
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy)) {
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
if (do_basic_tokenize) {
basic_tokenizer_ = std::make_unique<BasicTokenizer>(
do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
}
wordpiece_tokenizer_ = std::make_unique<WordpieceTokenizer>(
vocab_, unk_token, suffix_indicator);
vocab_, unk_token, suffix_indicator);
unk_token_id_ = vocab_->FindTokenId(unk_token);
sep_token_id_ = vocab_->FindTokenId(sep_token);
@ -276,8 +269,8 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
}
}
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
@ -288,14 +281,15 @@ KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo*
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name",
std::string("longest_first"));
int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
tokenizer_ = std::make_unique<BertTokenizer>(
vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token), ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents,
ustring(suffix_indicator), max_len, truncation_strategy_name);
vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token), ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents,
ustring(suffix_indicator), max_len, truncation_strategy_name);
}
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
@ -334,10 +328,6 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, output_dim, attention_mask);
}
void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
}
const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
size_t CustomOpBertTokenizer::GetInputTypeCount() const {
@ -356,11 +346,12 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBertTokenizer(api, info) {}
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* const input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
@ -380,7 +371,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
const std::vector<int64_t> inner_dims{1LL};
for (int32_t i = 0; i < 3; ++i) {
OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
OrtTensorTypeAndShapeInfo* const info = ort_.GetTensorTypeAndShape(value);
ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
ort_.ReleaseTensorTypeAndShapeInfo(info);
}
@ -390,10 +381,6 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, outer_dims, token_type_ids);
}
void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
}
const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {

Просмотреть файл

@ -42,8 +42,8 @@ class TruncateStrategy final {
class WordpieceTokenizer final {
public:
WordpieceTokenizer(
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word = 100);
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word = 100);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
@ -90,7 +90,7 @@ class BertTokenizer final {
};
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
protected:
@ -98,7 +98,6 @@ struct KernelBertTokenizer : BaseKernel {
};
struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
@ -107,12 +106,11 @@ struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelB
};
struct KernelHfBertTokenizer : KernelBertTokenizer {
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,16 +7,13 @@ BertTokenizerDecoder::BertTokenizerDecoder(
std::string pad_token,
std::string cls_token,
std::string mask_token,
std::string suffix_indicator
) :
unk_token_(unk_token),
suffix_indicator_(suffix_indicator),
raw_vocab_(vocab)
{
std::string suffix_indicator) : unk_token_(unk_token),
suffix_indicator_(suffix_indicator),
raw_vocab_(vocab) {
auto tokens = SplitString(raw_vocab_, "\n", true);
vocab_.reserve(tokens.size());
for (size_t i = 0; i < tokens.size(); i++) {
auto& token = tokens[i];
auto& token = tokens[i];
if (token == unk_token) {
unk_token_id_ = static_cast<int32_t>(i);
}
@ -82,7 +79,6 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
}
bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id) {
if (pre_token_id < 0) {
return true;
}
@ -123,8 +119,8 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
return false;
}
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
@ -154,7 +150,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
OrtTensorDimensions positions_dim(ort_, positions);
if (use_indices_ &&
(!((positions_dim.Size() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
}
@ -181,10 +177,6 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {

Просмотреть файл

@ -10,11 +10,10 @@
#include "string_utils.h"
#include "string_tensor.h"
class BertTokenizerDecoder {
public:
BertTokenizerDecoder(std::string vocab, std::string unk_token, std::string sep_token, std::string pad_token,
std::string cls_token,std::string mask_token,std::string suffix_indicator);
std::string cls_token, std::string mask_token, std::string suffix_indicator);
std::string Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces);
private:
@ -33,8 +32,9 @@ class BertTokenizerDecoder {
};
struct KernelBertTokenizerDecoder : BaseKernel {
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizerDecoder> decoder_;
bool use_indices_;
@ -43,10 +43,9 @@ struct KernelBertTokenizerDecoder : BaseKernel {
};
struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
};

Просмотреть файл

@ -9,8 +9,9 @@
#include <algorithm>
#include <memory>
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
if (model_data_.empty()) {
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
@ -24,7 +25,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
if (HasAttribute("max_sentence")) {
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(&info, "max_sentence"));
}
}
@ -78,10 +79,6 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
}
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {

Просмотреть файл

@ -16,8 +16,9 @@ extern "C" int FreeModel(void* ModelPtr);
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
struct KernelBlingFireSentenceBreaker : BaseKernel {
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
using ModelPtr = std::shared_ptr<void>;
ModelPtr model_;
@ -25,8 +26,8 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
int max_sentence;
};
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker,
KernelBlingFireSentenceBreaker> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -64,14 +64,14 @@ bool IsEmptyUstring(const ustring& str) {
return std::all_of(str.begin(), str.end(), [](char32_t ch) { return IsInUnicodeSpace(ch); });
}
KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
if (vocab.empty()) {
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
if (merges.empty()) {
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
@ -202,10 +202,6 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
void* CustomOpClipBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
}
const char* CustomOpClipBpeTokenizer::GetName() const {
return "CLIPTokenizer";
}

Просмотреть файл

@ -6,7 +6,7 @@
class VocabData;
struct KernelClipBpeTokenizer : BaseKernel {
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -18,7 +18,6 @@ struct KernelClipBpeTokenizer : BaseKernel {
};
struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokenizer, KernelClipBpeTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -22,7 +22,7 @@
#include "string_tensor.h"
#include "unicode.h"
//Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
// Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
bool IsUnicodeSpace(char32_t ch) {
switch (ch) {
case 0x0009:
@ -63,14 +63,14 @@ bool IsEmptyUString(const ustring& str) {
return std::all_of(str.begin(), str.end(), [](char32_t ch) { return IsUnicodeSpace(ch); });
}
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
if (vocab.empty()) {
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
if (merges.empty()) {
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
@ -182,10 +182,6 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
}
const char* CustomOpBpeTokenizer::GetName() const {
return "GPT2Tokenizer";
}

Просмотреть файл

@ -2,11 +2,10 @@
#include "ocos.h"
#include "ustring.h"
class VocabData;
struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -18,7 +17,6 @@ struct KernelBpeTokenizer : BaseKernel {
};
struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -10,16 +10,15 @@
#include "sentencepiece_model.pb.h"
struct KernelSentencepieceDecoder : BaseKernel {
KernelSentencepieceDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string model_blob = ort_.KernelInfoGetAttribute<std::string>(info, "model");
KernelSentencepieceDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string model_blob = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
sentencepiece::ModelProto model_proto;
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
if (!status.ok()){
ORTX_CXX_API_THROW(MakeString(
"Failed to create SentencePieceProcessor instance. Error code is ",
(int)status.code(), ". Message is '", status.error_message(), "'."),
ORT_INVALID_PROTOBUF);
if (!status.ok()) {
ORTX_CXX_API_THROW(MakeString("Failed to create SentencePieceProcessor instance. Error code is ",
(int)status.code(), ". Message is '", status.error_message(), "'."),
ORT_INVALID_PROTOBUF);
}
}
@ -40,7 +39,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
std::back_inserter(tids),
[](auto _id) { return static_cast<int>(_id); });
auto status = tokenizer_.Decode(tids, &decoded_string);
if (!status.ok()){
if (!status.ok()) {
ORTX_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION);
}
@ -54,10 +53,6 @@ struct KernelSentencepieceDecoder : BaseKernel {
};
struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
}
const char* GetName() const {
return "SentencepieceDecoder";
}

Просмотреть файл

@ -7,8 +7,9 @@
#include "string_tensor.h"
#include "base64.h"
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes;
if (base64_decode(model_as_string, model_as_bytes)) {
@ -18,16 +19,16 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
}
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
if (!status.ok())
throw std::runtime_error(MakeString(
"Failed to create SentencePieceProcessor instance. Error code is ",
(int)status.code(), ". Message is '", status.error_message(), "'."));
ORTX_CXX_API_THROW(MakeString("Failed to create SentencePieceProcessor instance. Error code is ",
(int)status.code(), ". Message is '", status.error_message(), "'."),
ORT_FAIL);
}
static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort_value, const char* name) {
OrtTensorDimensions dimensions(ort, ort_value);
if (dimensions.size() != 1 || dimensions[0] != 1)
throw std::runtime_error(MakeString(
name, " must contain only one element. It has ", dimensions.size(), " dimensions."));
ORTX_CXX_API_THROW(MakeString(name, " must contain only one element. It has ", dimensions.size(), " dimensions."),
ORT_INVALID_ARGUMENT);
}
void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
@ -64,8 +65,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
for (size_t i = 0; i < str_input.size(); ++i) {
std::vector<int> inloop;
if (!tokenizer_.Encode(str_input[i].c_str(), &inloop).ok())
throw std::runtime_error(MakeString(
"Unable to encode string '", str_input[i], "'."));
ORTX_CXX_API_THROW(MakeString("Unable to encode string '", str_input[i], "'."), ORT_INVALID_ARGUMENT);
indices.push_back(content.size());
if (*p_add_rev) {
@ -103,10 +103,6 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
}
void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpSentencepieceTokenizer::GetName() const {
return "SentencepieceTokenizer";
};
@ -128,7 +124,7 @@ ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetInputType(size_t in
case 5:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
@ -143,6 +139,6 @@ ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetOutputType(size_t i
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("[SentencepieceTokenizer] Unexpected output index ", index));
ORTX_CXX_API_THROW(MakeString("[SentencepieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -7,17 +7,16 @@
#include "string_utils.h"
#include "sentencepiece_processor.h"
struct KernelSentencepieceTokenizer : BaseKernel {
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
sentencepiece::SentencePieceProcessor tokenizer_;
};
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer,
KernelSentencepieceTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,13 +4,16 @@
#include "wordpiece_tokenizer.hpp"
#include "nlohmann/json.hpp"
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(info, "suffix_indicator");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unknown_token");
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word") ? ort_.KernelInfoGetAttribute<int64_t>(info, "max_input_chars_per_word") : 200;
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(&info, "suffix_indicator");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unknown_token");
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word")
? ort_.KernelInfoGetAttribute<int64_t>(&info, "max_input_chars_per_word")
: 200;
suffix_indicator_ = ustring(suffix_indicator);
unk_token_ = ustring(unk);
@ -73,7 +76,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
} else if (text_index == existing_rows[row_index]) {
if (row_index >= n_existing_rows)
ORTX_CXX_API_THROW(MakeString(
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
"row_index=", row_index, " is out of range=", n_existing_rows, "."),
ORT_INVALID_ARGUMENT);
rows.push_back(indices.size());
++row_index;
}
@ -162,10 +166,6 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
}
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
};
const char* CustomOpWordpieceTokenizer::GetName() const {
return "WordpieceTokenizer";
};

Просмотреть файл

@ -11,7 +11,7 @@
#include "string_tensor.h"
struct KernelWordpieceTokenizer : BaseKernel {
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
@ -22,7 +22,6 @@ struct KernelWordpieceTokenizer : BaseKernel {
};
struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
@ -43,4 +42,4 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
std::vector<int64_t>& rows,
const int64_t* existing_rows = nullptr,
int64_t n_existing_rows = 0,
int64_t max_input_chars_per_word = 200);
int64_t max_input_chars_per_word = 200);

Просмотреть файл

@ -10,16 +10,12 @@
namespace ort_extensions {
struct KernelDecodeImage : BaseKernel {
KernelDecodeImage(const OrtApi& api) : BaseKernel(api) {}
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
void Compute(OrtKernelContext* context);
};
struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelDecodeImage(api);
}
void KernelDestroy(void* op_kernel) {
delete static_cast<KernelDecodeImage*>(op_kernel);
}

Просмотреть файл

@ -10,15 +10,20 @@
namespace ort_extensions {
struct KernelEncodeImage : BaseKernel {
KernelEncodeImage(const OrtApi& api, const std::string& format)
: BaseKernel{api},
extension_{std::string(".") + format} {
KernelEncodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel{api, info} {
OrtW::CustomOpApi op_api{api};
std::string format = op_api.KernelInfoGetAttribute<std::string>(&info, "format");
if (format != "jpg" && format != "png") {
ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
}
extension_ = std::string(".") + format;
}
void Compute(OrtKernelContext* context);
private:
const std::string extension_;
std::string extension_;
};
/// <summary>
@ -28,16 +33,6 @@ struct KernelEncodeImage : BaseKernel {
/// Default is 'jpg'
/// </summary>
struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
OrtW::CustomOpApi op_api{api};
std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
if (format != "jpg" && format != "png") {
ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
}
return new KernelEncodeImage(api, format);
}
const char* GetName() const {
return "EncodeImage";
}

Просмотреть файл

@ -210,7 +210,7 @@ typedef struct {
std::vector<int64_t> dimensions;
} InputInformation;
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info,
uint64_t id, const std::vector<std::string>& attrs)
: api_(api),
ort_(api_),
@ -218,7 +218,7 @@ PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
size_t size;
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), nullptr, &size);
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
std::string error_message(api_.GetErrorMessage(status));
api_.ReleaseStatus(status);
@ -231,7 +231,7 @@ PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
}
attrs_values_[*it] = "";
attrs_values_[*it].resize(size);
status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), &(attrs_values_[*it][0]), &size);
if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
api_.ReleaseStatus(status);
throw std::runtime_error(MakeString(
@ -373,7 +373,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
op_domain = cod->op_type.substr(0, dm_pos);
op = cod->op_type.substr(dm_pos + 2, -1);
}
// No need to protect against concurrent access, GIL is doing that.
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
const auto [it_domain_op, success] = PyOp_container().insert(val);
@ -386,7 +386,7 @@ const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t num) {
EnablePyCustomOps(false);
return nullptr;
}
auto it = PyOp_container().find(c_OpDomain);
if (it != PyOp_container().end()) {
const std::vector<PyCustomOpFactory>& ref = it->second;
@ -414,7 +414,7 @@ bool EnablePyCustomOps(bool enabled) {
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* ortApi){
OrtCustomOpDomain* domain = nullptr;
OrtStatus* status = nullptr;
for (auto const& val_pair: PyOp_container()) {
if (val_pair.first == c_OpDomain) {
continue; // Register this domain in the second iteration.

Просмотреть файл

@ -37,7 +37,7 @@ struct PyCustomOpDef {
};
struct PyCustomOpKernel {
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info, uint64_t id, const std::vector<std::string>& attrs);
void Compute(OrtKernelContext* context);
private:
@ -49,7 +49,7 @@ struct PyCustomOpKernel {
struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
PyCustomOpFactory() {
// STL vector needs it.
// STL vector needs it.
}
PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
@ -60,8 +60,8 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
op_type_ = op;
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info, opdef_->obj_id, opdef_->attrs);
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
};
const char* GetName() const {

Просмотреть файл

@ -60,15 +60,19 @@ class ExternalCustomOps {
};
// Registers an externally-defined custom op so a later RegisterCustomOps call will include it.
// Body is wrapped in OCOS_API_IMPL_BEGIN/OCOS_API_IMPL_END so an exception thrown inside
// cannot propagate across this C ABI boundary when exception propagation is disabled
// (per the change description, the handler logs and aborts instead of rethrowing).
// NOTE(review): always returns true — a failure surfaces via the macro's handling,
// not via the return value.
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
OCOS_API_IMPL_BEGIN
// Add to the process-wide singleton registry of external custom ops.
ExternalCustomOps::instance().Add(c_op);
OCOS_API_IMPL_END
return true;
}
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
OrtStatus* status = nullptr;
OCOS_API_IMPL_BEGIN
OrtCustomOpDomain* domain = nullptr;
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
std::set<std::string> pyop_nameset;
OrtStatus* status = nullptr;
#if defined(PYTHON_OP_SUPPORT)
if (status = RegisterPythonDomainAndOps(options, ortApi); status) {
@ -170,5 +174,9 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
}
}
return ortApi->AddCustomOpDomain(options, domain);
status = ortApi->AddCustomOpDomain(options, domain);
OCOS_API_IMPL_END
return status;
}

Двоичные данные
test/data/exceptional_custom_op1.onnx Normal file

Двоичный файл не отображается.

Двоичные данные
test/data/exceptional_custom_op2.onnx Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,120 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "ocos.h"
#include "test_kernel.hpp"
// throw in ctor which will be called during model load
// Test kernel whose constructor always throws via ORTX_CXX_API_THROW (ORT_FAIL).
// Kernel construction happens during model load, so this exercises the
// exception path taken while a session is being created.
struct ExceptionalKernel1 : BaseKernel {
ExceptionalKernel1(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
ORTX_CXX_API_THROW("Throw in ctor", ORT_FAIL);
}

// Intentionally empty: the ctor throws, so Compute is never reached.
void Compute(OrtKernelContext* context) {}
};
// throw in Compute which will be called during model execution
// Test kernel that constructs successfully but always throws from Compute
// (ORT_FAIL). Compute runs during model execution, so this exercises the
// exception path taken while Session::Run is in progress.
struct ExceptionalKernel2 : BaseKernel {
ExceptionalKernel2(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void Compute(OrtKernelContext* context) {
ORTX_CXX_API_THROW("Throw in Compute", ORT_FAIL);
}
};
// Custom op schema for the load-time-throwing kernel (ExceptionalKernel1):
// op name "ExceptionalCustomOp1", one float tensor in, one float tensor out.
struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
  const char* GetName() const {
    return "ExceptionalCustomOp1";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
// Custom op schema for the run-time-throwing kernel (ExceptionalKernel2):
// op name "ExceptionalCustomOp2", one float tensor in, one float tensor out.
struct ExceptionalCustomOp2 : OrtW::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
  const char* GetName() const {
    return "ExceptionalCustomOp2";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
static ExceptionalCustomOp1 custom_op1;
static ExceptionalCustomOp2 custom_op2;
// test a call to an entry point wrapped with OCOS_API_IMPL_BEGIN/OCOS_API_IMPL_END behaves as expected.
// the throw in the ctor of ExceptionalCustomOp1 should be triggered during model loading.
TEST(Exceptions, TestApiTryCatch_ThrowInModelLoad) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

// Register the custom op whose kernel ctor throws, then route it through
// the OCOS_API_IMPL-wrapped RegisterCustomOps entry point.
AddExternalCustomOp(&custom_op1);
Ort::SessionOptions session_options;
RegisterCustomOps((OrtSessionOptions*)session_options, OrtGetApiBase());

// Model containing a node that uses ExceptionalCustomOp1, so kernel
// construction (and the throw) happens during session creation.
std::filesystem::path model("data/exceptional_custom_op1.onnx");
auto fail_fn = [&]() {
Ort::Session session(*ort_env, model.c_str(), session_options);
};

// if no exceptions, the ORTX_CXX_API_THROW will trigger the log+abort
// if no exception propagation, the OCOS_API_IMPL_END will trigger the log+abort
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
// the exception should be caught and logged, and the process should abort so the exception is not propagated up.
// log output needs to be manually checked
// can test on Linux but not Windows.
#if !defined(_WIN32)
// gtest death test: run fail_fn in a child process and expect SIGABRT.
EXPECT_EXIT(fail_fn(), ::testing::KilledBySignal(SIGABRT), ".*");
#endif
#else
// ORT catches the exceptions thrown by the custom op and rethrows them as Ort::Exception
EXPECT_THROW(fail_fn(), Ort::Exception);
#endif
}
// test a call to an entry point wrapped with OCOS_API_IMPL_BEGIN/OCOS_API_IMPL_END behaves as expected.
// the throw in the Compute of ExceptionalCustomOp2 should be triggered during model execution.
TEST(Exceptions, TestApiTryCatch_ThrowInModelExecution) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

// Register the custom op whose Compute throws; its ctor succeeds, so model
// load is expected to complete and the failure happens only at Run time.
AddExternalCustomOp(&custom_op2);
Ort::SessionOptions session_options;
RegisterCustomOps((OrtSessionOptions*)session_options, OrtGetApiBase());

std::filesystem::path model("data/exceptional_custom_op2.onnx");
Ort::Session session(*ort_env, model.c_str(), session_options);

// Build a minimal float input tensor {0, 1} for input "A" so Run reaches
// the custom op's Compute.
Ort::AllocatorWithDefaultOptions allocator;
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
const char* input_names[] = {"A"};
const char* output_names[] = {"B"};

std::vector<int64_t> dims = {2};
std::vector<float> input = {0.f, 1.f};
std::vector<Ort::Value> ort_input;
ort_input.push_back(Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(),
dims.data(), dims.size()));

auto fail_fn = [&]() {
// executing the model should call Compute of the custom op, which should throw
std::vector<Ort::Value> ort_outputs;
ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, ort_input.data(), ort_input.size(),
output_names, 1);
};

// if no exceptions, the ORTX_CXX_API_THROW will trigger the log+abort
// if no exception propagation, the OCOS_API_IMPL_END will trigger the log+abort
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
// can test on Linux but not Windows
#if !defined(_WIN32)
// gtest death test: run fail_fn in a child process and expect SIGABRT.
EXPECT_EXIT(fail_fn(), ::testing::KilledBySignal(SIGABRT), ".*");
#endif
#else
// ORT catches the exceptions thrown by the custom op and rethrows them as Ort::Exception
EXPECT_THROW(fail_fn(), Ort::Exception);
#endif
}

Просмотреть файл

@ -1,40 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include <fstream>
#include <vector>
#include "gtest/gtest.h"
#include "exceptions.h"
#define TEST_MAIN main
#if defined(__APPLE__)
#include <TargetConditionals.h>
#if TARGET_OS_SIMULATOR || TARGET_OS_IOS
#undef TEST_MAIN
#define TEST_MAIN main_no_link_ // there is a UI test app for iOS.
#endif
#include <TargetConditionals.h>
#if TARGET_OS_SIMULATOR || TARGET_OS_IOS
#undef TEST_MAIN
#define TEST_MAIN main_no_link_ // there is a UI test app for iOS.
#endif
// currently this is the only place with a try/catch. Move the macros to common code if that changes.
#ifdef OCOS_NO_EXCEPTIONS
#define OCOS_TRY if (true)
#define OCOS_CATCH(x) else if (false)
#define OCOS_RETHROW
// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
// a lambda function. otherwise the exception referred will be undefined and cause build break
#define OCOS_HANDLE_EXCEPTION(func)
#else
#define OCOS_TRY try
#define OCOS_CATCH(x) catch (x)
#define OCOS_RETHROW throw;
#define OCOS_HANDLE_EXCEPTION(func) func()
#endif
namespace {
void FixCurrentDir() {
// adjust for the Google Test Adapter in Visual Studio not setting the current path to $(ProjectDir),

Просмотреть файл

@ -21,7 +21,7 @@ const char* GetLibraryPath() {
}
struct KernelOne : BaseKernel {
KernelOne(const OrtApi& api) : BaseKernel(api) {
KernelOne(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
@ -67,7 +67,7 @@ struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
};
struct KernelTwo : BaseKernel {
KernelTwo(const OrtApi& api) : BaseKernel(api) {
KernelTwo(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
@ -110,7 +110,7 @@ struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
};
struct KernelThree : BaseKernel {
KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelThree(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
if (!TryToGetAttribute("substr", substr_)) {
substr_ = "";
}
@ -137,8 +137,12 @@ struct KernelThree : BaseKernel {
};
struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info);
// This is example code to show how to override the CustomOpBase::CreateKernel method even though it is not virtual.
// The CustomOpBase implementation will call the CreateKernel of the first class specified in the template,
// and from there it's also possible to call the base CreateKernel as per below.
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
std::cout << "Called CreateKernel override" << std::endl;
return OrtW::CustomOpBase<CustomOpThree, KernelThree>::CreateKernel(api, info);
};
const char* GetName() const {
return "CustomOpThree";
@ -189,11 +193,11 @@ void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::v
OrtTensorDimensions dimensions(OrtW::CustomOpApi(api), value);
size_t len = static_cast<size_t>(dimensions.Size());
size_t data_len;
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
output.resize(len);
std::vector<char> result(data_len + len + 1, '\0');
std::vector<size_t> offsets(len);
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
output.resize(len);
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
if (i < len - 1)