diff --git a/.gitignore b/.gitignore index 36ef4dbb..130a530f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ build_host_protoc build_android build_ios build_* +_subbuild/ .build_debug/* .build_release/* distribute/* diff --git a/CMakeLists.txt b/CMakeLists.txt index a783f43b..ee4be1f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,9 @@ option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" function(disable_all_operators) + set(OCOS_ENABLE_RE2_REGEX OFF CACHE INTERNAL "") set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "") + set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "") set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "") set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "") set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "") @@ -122,6 +124,11 @@ if (OCOS_ENABLE_TF_STRING) list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH}) endif() +if (OCOS_ENABLE_RE2_REGEX) + file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*") + list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS}) +endif() + if (OCOS_ENABLE_MATH) set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "") set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "") diff --git a/includes/ocos.h b/includes/ocos.h index 6f0f7c36..a396beaa 100644 --- a/includes/ocos.h +++ b/includes/ocos.h @@ -59,9 +59,6 @@ struct OrtTensorDimensions : std::vector { }; -struct CustomOpClassBegin{ -}; - template class CuopContainer { public: @@ -86,6 +83,9 @@ class CuopContainer { std::vector ocos_list_; }; +struct CustomOpClassBegin{ +}; + typedef std::function FxLoadCustomOpFactory; template @@ -99,3 +99,15 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count); OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*); bool EnablePyCustomOps(bool enable = true); #endif + +#ifdef ENABLE_MATH +extern FxLoadCustomOpFactory LoadCustomOpClasses_Math; +#endif // ENABLE_MATH + +#ifdef ENABLE_TOKENIZER +extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer; +#endif // ENABLE_TOKENIZER + +#ifdef ENABLE_TF_STRING +extern FxLoadCustomOpFactory LoadCustomOpClasses_Text; +#endif // ENABLE_TF_STRING diff --git a/operators/math/math.cc b/operators/math/math.cc index f9eeec9e..6ca2896e 100644 --- a/operators/math/math.cc +++ b/operators/math/math.cc @@ -1,6 +1,13 @@ #include "ocos.h" #include "negpos.hpp" #include "inverse.hpp" +#include "segment_extraction.hpp" +#include "segment_sum.hpp" -FxLoadCustomOpFactory LoadCustomOpClasses_Math = LoadCustomOpClasses; +FxLoadCustomOpFactory LoadCustomOpClasses_Math = + LoadCustomOpClasses; diff --git a/operators/text/op_segement_extraction.cc b/operators/math/segement_extraction.cc similarity index 98% rename from operators/text/op_segement_extraction.cc rename to operators/math/segement_extraction.cc index 3ae086da..461168bd 100644 --- a/operators/text/op_segement_extraction.cc +++ b/operators/math/segement_extraction.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "op_segment_extraction.hpp" +#include "segment_extraction.hpp" KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) { } diff --git a/operators/text/op_segment_extraction.hpp b/operators/math/segment_extraction.hpp similarity index 96% rename from operators/text/op_segment_extraction.hpp rename to operators/math/segment_extraction.hpp index bcdb9c3e..9425b33d 100644 --- a/operators/text/op_segment_extraction.hpp +++ b/operators/math/segment_extraction.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelSegmentExtraction : BaseKernel { diff --git a/operators/text/op_segment_sum.cc b/operators/math/segment_sum.cc similarity index 99% rename from operators/text/op_segment_sum.cc rename to operators/math/segment_sum.cc index 79ffd72a..b81e5126 100644 --- a/operators/text/op_segment_sum.cc +++ b/operators/math/segment_sum.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "op_segment_sum.hpp" +#include "segment_sum.hpp" template void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) { diff --git a/operators/text/op_segment_sum.hpp b/operators/math/segment_sum.hpp similarity index 96% rename from operators/text/op_segment_sum.hpp rename to operators/math/segment_sum.hpp index b985ffe1..c3f4505a 100644 --- a/operators/text/op_segment_sum.hpp +++ b/operators/math/segment_sum.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelSegmentSum : BaseKernel { diff --git a/operators/text/kernels.h b/operators/text/kernels.h deleted file mode 100644 index 7168bac7..00000000 --- a/operators/text/kernels.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "ocos.h" - - -// TO BE DELETED. diff --git a/operators/text/op_equal.hpp b/operators/text/op_equal.hpp index 0bed3107..18a4b261 100644 --- a/operators/text/op_equal.hpp +++ b/operators/text/op_equal.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringEqual : BaseKernel { diff --git a/operators/text/op_equal_impl.hpp b/operators/text/op_equal_impl.hpp index de97a8a3..9186cf46 100644 --- a/operators/text/op_equal_impl.hpp +++ b/operators/text/op_equal_impl.hpp @@ -13,7 +13,7 @@ class BroadcastIteratorRight { const std::vector& shape2, const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) { if (shape2.size() > shape1.size()) - throw std::runtime_error("shape2 must have less dimensions than shape1"); + ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT); shape2_.resize(shape1_.size()); cum_shape2_.resize(shape1_.size()); total_ = 1; @@ -26,8 +26,8 @@ class BroadcastIteratorRight { shape2_[i] = shape2[i]; } if (shape2[i] != 1 && shape1[i] != shape2[i]) { - throw std::runtime_error(MakeString( - "Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i])); + ORT_CXX_API_THROW(MakeString( + "Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT); } } cum_shape2_[shape2_.size() - 1] = 1; @@ -84,7 +84,7 @@ class BroadcastIteratorRight { template void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) { if (pos != 0) - throw std::runtime_error("Not implemented yet."); + ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED); while (!end()) { *p3 = cmp(*p1, *p2); next(); diff --git a/operators/text/op_ragged_tensor.cc b/operators/text/op_ragged_tensor.cc index c8c7bded..7e312548 100644 --- a/operators/text/op_ragged_tensor.cc +++ b/operators/text/op_ragged_tensor.cc @@ -14,8 +14,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) { OrtTensorDimensions d_length(ort_, n_elements); if (d_length.size() != 1) - throw std::runtime_error(MakeString( - "First input must have one dimension not ", d_length.size(), ".")); + ORT_CXX_API_THROW(MakeString( + "First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT); int64_t n_els = d_length[0] - 1; int64_t n_values = p_n_elements[n_els]; std::vector shape{n_values, 2}; @@ -110,9 +110,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) { for (int64_t i = 0; i < size - 1; ++i) { pos_end = pos + max_col; if (pos_end > shape_out_size) - throw std::runtime_error(MakeString( + ORT_CXX_API_THROW(MakeString( "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1], - " - i=", i, " size=", size, ".")); + " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT); for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) { dense[pos] = p_values[j]; } @@ -168,9 +168,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) { for (int64_t i = 0; i < size - 1; ++i) { pos_end = pos + max_col; if (pos_end > shape_out_size) - throw std::runtime_error(MakeString( + ORT_CXX_API_THROW(MakeString( "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1], - " - i=", i, " size=", size, ".")); + " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT); for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) { dense[pos] = input[j]; } @@ -210,6 +210,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t case 3: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, ".")); + ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT); } }; diff --git a/operators/text/op_ragged_tensor.hpp b/operators/text/op_ragged_tensor.hpp index 7fb3c8dd..eb0241b9 100644 --- a/operators/text/op_ragged_tensor.hpp +++ b/operators/text/op_ragged_tensor.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" struct KernelRaggedTensorToSparse : BaseKernel { KernelRaggedTensorToSparse(OrtApi api); diff --git a/operators/text/string_regex_replace.cc b/operators/text/re2_strings/string_regex_replace.cc similarity index 100% rename from operators/text/string_regex_replace.cc rename to operators/text/re2_strings/string_regex_replace.cc diff --git a/operators/text/string_regex_replace.hpp b/operators/text/re2_strings/string_regex_replace.hpp similarity index 97% rename from operators/text/string_regex_replace.hpp rename to operators/text/re2_strings/string_regex_replace.hpp index 2fb63fb6..6d3cb982 100644 --- a/operators/text/string_regex_replace.hpp +++ b/operators/text/re2_strings/string_regex_replace.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringRegexReplace : BaseKernel { diff --git a/operators/text/string_regex_split.cc b/operators/text/re2_strings/string_regex_split.cc similarity index 100% rename from operators/text/string_regex_split.cc rename to operators/text/re2_strings/string_regex_split.cc index 91650306..544a8cc8 100644 --- a/operators/text/string_regex_split.cc +++ b/operators/text/re2_strings/string_regex_split.cc @@ -3,9 +3,9 @@ #include "string_regex_split.hpp" #include "string_regex_split_re.hpp" +#include "string_tensor.h" #include #include -#include "string_tensor.h" KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { } diff --git a/operators/text/string_regex_split.hpp b/operators/text/re2_strings/string_regex_split.hpp similarity index 97% rename from operators/text/string_regex_split.hpp rename to operators/text/re2_strings/string_regex_split.hpp index bc182729..9d54ada2 100644 --- a/operators/text/string_regex_split.hpp +++ b/operators/text/re2_strings/string_regex_split.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" // See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md. diff --git a/operators/text/string_regex_split_re.hpp b/operators/text/re2_strings/string_regex_split_re.hpp similarity index 100% rename from operators/text/string_regex_split_re.hpp rename to operators/text/re2_strings/string_regex_split_re.hpp diff --git a/operators/text/string_concat.cc b/operators/text/string_concat.cc index 15699b0c..39093714 100644 --- a/operators/text/string_concat.cc +++ b/operators/text/string_concat.cc @@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) { OrtTensorDimensions right_dim(ort_, right); if (left_dim != right_dim) { - throw std::runtime_error("Two input tensor should have the same dimension."); + ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT); } std::vector left_value; diff --git a/operators/text/string_concat.hpp b/operators/text/string_concat.hpp index 2be1dcf9..a96fa73e 100644 --- a/operators/text/string_concat.hpp +++ b/operators/text/string_concat.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringConcat : BaseKernel { diff --git a/operators/text/string_ecmaregex_replace.hpp b/operators/text/string_ecmaregex_replace.hpp index 5e9b34e9..1098f1e3 100644 --- a/operators/text/string_ecmaregex_replace.hpp +++ b/operators/text/string_ecmaregex_replace.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringECMARegexReplace : BaseKernel { diff --git a/operators/text/string_ecmaregex_split.hpp b/operators/text/string_ecmaregex_split.hpp index 01a6c8ab..16fe8934 100644 --- a/operators/text/string_ecmaregex_split.hpp +++ b/operators/text/string_ecmaregex_split.hpp @@ -4,7 +4,7 @@ #pragma once #include -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" // See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md. diff --git a/operators/text/string_hash.cc b/operators/text/string_hash.cc index 9a6f7482..74f3e9c8 100644 --- a/operators/text/string_hash.cc +++ b/operators/text/string_hash.cc @@ -23,9 +23,9 @@ void KernelStringHash::Compute(OrtKernelContext* context) { // Verifications OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets); if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1) - throw std::runtime_error(MakeString( + ORT_CXX_API_THROW(MakeString( "num_buckets must contain only one element. It has ", - num_buckets_dimensions.size(), " dimensions.")); + num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); // Setup output OrtTensorDimensions dimensions(ort_, input); @@ -60,7 +60,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const { case 1: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("Unexpected input index ", index)); + ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); } }; @@ -86,9 +86,9 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) { // Verifications OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets); if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1) - throw std::runtime_error(MakeString( + ORT_CXX_API_THROW(MakeString( "num_buckets must contain only one element. It has ", - num_buckets_dimensions.size(), " dimensions.")); + num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); // Setup output OrtTensorDimensions dimensions(ort_, input); @@ -123,7 +123,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con case 1: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("Unexpected input index ", index)); + ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); } }; diff --git a/operators/text/string_hash.hpp b/operators/text/string_hash.hpp index 82c73be3..6bf87d78 100644 --- a/operators/text/string_hash.hpp +++ b/operators/text/string_hash.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringHash : BaseKernel { diff --git a/operators/text/string_join.cc b/operators/text/string_join.cc index 9b35e485..5f20a03c 100644 --- a/operators/text/string_join.cc +++ b/operators/text/string_join.cc @@ -20,13 +20,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) { // Setup output OrtTensorDimensions dimensions_sep(ort_, input_sep); if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1) - throw std::runtime_error("Input 2 is the separator, it has 1 element."); + ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT); OrtTensorDimensions dimensions_axis(ort_, input_axis); if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1) - throw std::runtime_error("Input 3 is the axis, it has 1 element."); + ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT); OrtTensorDimensions dimensions(ort_, input_X); if (*axis < 0 || *axis >= dimensions.size()) - throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis)); + ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT); std::vector dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1); if (dimensions.size() > 1) { @@ -89,7 +89,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const { case 2: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("Unexpected input index ", index)); + ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); } }; diff --git a/operators/text/string_join.hpp b/operators/text/string_join.hpp index abcc5a50..76511c77 100644 --- a/operators/text/string_join.hpp +++ b/operators/text/string_join.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringJoin : BaseKernel { diff --git a/operators/text/string_length.hpp b/operators/text/string_length.hpp index 0589abfc..f5b0822e 100644 --- a/operators/text/string_length.hpp +++ b/operators/text/string_length.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringLength : BaseKernel { diff --git a/operators/text/string_lower.hpp b/operators/text/string_lower.hpp index 496161fb..31b79210 100644 --- a/operators/text/string_lower.hpp +++ b/operators/text/string_lower.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringLower : BaseKernel { diff --git a/operators/text/string_split.cc b/operators/text/string_split.cc index fb8430c1..ce623a4a 100644 --- a/operators/text/string_split.cc +++ b/operators/text/string_split.cc @@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) { // Setup output OrtTensorDimensions dimensions_sep(ort_, input_sep); if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1) - throw std::runtime_error("Input 2 is the delimiter, it has 1 element."); + ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT); OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty); if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1) - throw std::runtime_error("Input 3 is skip_empty, it has 1 element."); + ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT); OrtTensorDimensions dimensions(ort_, input_X); if (dimensions.size() != 1) - throw std::runtime_error("Only 1D tensor are supported as input."); + ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT); std::vector words; std::vector indices; @@ -116,7 +116,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const case 2: return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; default: - throw std::runtime_error(MakeString("Unexpected input index ", index)); + ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); } }; @@ -132,6 +132,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const case 1: return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; default: - throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index)); + ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT); } }; diff --git a/operators/text/string_split.hpp b/operators/text/string_split.hpp index bc14a144..7edfc1ab 100644 --- a/operators/text/string_split.hpp +++ b/operators/text/string_split.hpp @@ -3,7 +3,7 @@ #pragma once -#include "kernels.h" +#include "ocos.h" #include "string_utils.h" struct KernelStringSplit : BaseKernel { diff --git a/operators/text/string_to_vector.cc b/operators/text/string_to_vector.cc index 4814e082..58e94586 100644 --- a/operators/text/string_to_vector.cc +++ b/operators/text/string_to_vector.cc @@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) { vector_len_ = ParseVectorLen(lines[0]); if (vector_len_ == 0) { - throw std::runtime_error(MakeString("The mapped value of string input cannot be empty: ", lines[0])); + ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT); } std::vector values(vector_len_); @@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) { auto kv = SplitString(line, "\t", true); if (kv.size() != 2) { - throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line)); + ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); } ParseValues(kv[1], values); @@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) { void StringToVectorImpl::ParseUnkownValue(std::string& unk) { auto unk_strs = SplitString(unk, " ", true); if (unk_strs.size() != vector_len_) { - throw std::runtime_error(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_)); + ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT); } for (auto& str : unk_strs) { int64_t value; auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value); if (end != str.data() + str.size()) { - throw std::runtime_error(MakeString("Failed to parse unknown_value when processing the number: ", str)); + ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT); } unk_value_.push_back(value); @@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) { auto kv = SplitString(line, "\t", true); if (kv.size() != 2) { - throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line)); + ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); } auto value_strs = SplitString(kv[1], " ", true); @@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector; diff --git a/operators/text/vector_to_string.cc b/operators/text/vector_to_string.cc index 6cc8696d..4725e920 100644 --- a/operators/text/vector_to_string.cc +++ b/operators/text/vector_to_string.cc @@ -29,7 +29,7 @@ std::vector VectorToStringImpl::Compute(const void* input, const Or output_dim = input_dim; } else { if (input_dim[input_dim.size() - 1] != vector_len_) { - throw std::runtime_error(MakeString("Incompatible dimension: required vector length should be ", vector_len_)); + ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT); } output_dim = input_dim; @@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) { auto kv = SplitString(line, "\t", true); if (kv.size() != 2) { - throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line)); + ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); } ParseValues(kv[1], values); @@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) { auto kv = SplitString(line, "\t", true); if (kv.size() != 2) { - throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line)); + ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); } auto value_strs = SplitString(kv[1], " ", true); @@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector(info, "model"); if (model_data_.empty()) { - throw std::runtime_error("vocabulary shouldn't be empty."); + ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT); } void* model_ptr = SetModel(reinterpret_cast(model_data_.data()), model_data_.size()); if (model_ptr == nullptr) { - throw std::runtime_error("Invalid model"); + ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT); } model_ = std::shared_ptr(model_ptr, FreeModel); @@ -33,7 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) { OrtTensorDimensions dimensions(ort_, input); if (dimensions.Size() != 1 && dimensions[0] != 1) { - throw std::runtime_error("We only support string scalar."); + ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT); } std::vector input_data; @@ -46,7 +46,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) { int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get()); if (output_length < 0) { - throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed")); + ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT); } // inline split output_str by newline '\n' diff --git a/operators/tokenizer/wordpiece_tokenizer.cc b/operators/tokenizer/wordpiece_tokenizer.cc index 1e2bd2e9..ae0eba0e 100644 --- a/operators/tokenizer/wordpiece_tokenizer.cc +++ b/operators/tokenizer/wordpiece_tokenizer.cc @@ -72,8 +72,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map= n_existing_rows) - throw std::runtime_error(MakeString( - "row_index=", row_index, " is out of range=", n_existing_rows, ".")); + ORT_CXX_API_THROW(MakeString( + "row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT); rows.push_back(indices.size()); ++row_index; } @@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index) case 1: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("Unexpected input index ", index)); + ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); } }; @@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index case 3: return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; default: - throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index)); + ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT); } }; \ No newline at end of file diff --git a/shared/ortcustomops.cc b/shared/ortcustomops.cc index f77309ba..143b6b69 100644 --- a/shared/ortcustomops.cc +++ b/shared/ortcustomops.cc @@ -4,89 +4,7 @@ #include #include "onnxruntime_extensions.h" -#include "string_utils.h" - -#include "text/op_equal.hpp" -#include "text/op_segment_sum.hpp" -#include "text/op_segment_extraction.hpp" -#include "text/op_ragged_tensor.hpp" -#include "text/string_hash.hpp" -#include "text/string_join.hpp" -#include "text/string_lower.hpp" -#include "text/string_regex_replace.hpp" -#include "text/string_regex_split.hpp" -#include "text/string_split.hpp" -#include "text/string_to_vector.hpp" -#include "text/string_upper.hpp" -#include "text/vector_to_string.hpp" -#include "text/string_length.hpp" -#include "text/string_concat.hpp" -#include "text/string_ecmaregex_replace.hpp" -#include "text/string_ecmaregex_split.hpp" - - -#ifdef ENABLE_TF_STRING -CustomOpSegmentExtraction c_CustomOpSegmentExtraction; -CustomOpSegmentSum c_CustomOpSegmentSum; -CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense; -CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse; -CustomOpStringEqual c_CustomOpStringEqual; -CustomOpStringHash c_CustomOpStringHash; -CustomOpStringHashFast c_CustomOpStringHashFast; -CustomOpStringJoin c_CustomOpStringJoin; -CustomOpStringLower c_CustomOpStringLower; -CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense; -CustomOpStringECMARegexReplace c_CustomOpStringECMARegexReplace; -CustomOpStringECMARegexSplitWithOffsets c_CustomOpStringECMARegexSplitWithOffsets; -CustomOpStringSplit c_CustomOpStringSplit; -CustomOpStringToVector c_CustomOpStringToVector; -CustomOpStringUpper c_CustomOpStringUpper; -CustomOpVectorToString c_CustomOpVectorToString; -CustomOpStringLength c_CustomOpStringLength; -CustomOpStringConcat c_CustomOpStringConcat; -#endif - -#ifdef ENABLE_RE2_REGEX -CustomOpStringRegexReplace c_CustomOpStringRegexReplace; -CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets; -#endif - -OrtCustomOp* operator_lists[] = { -#ifdef ENABLE_TF_STRING - &c_CustomOpRaggedTensorToDense, - &c_CustomOpRaggedTensorToSparse, - &c_CustomOpSegmentSum, - &c_CustomOpSegmentExtraction, - &c_CustomOpStringEqual, - &c_CustomOpStringHash, - &c_CustomOpStringHashFast, - &c_CustomOpStringJoin, - &c_CustomOpStringLower, - &c_CustomOpStringRaggedTensorToDense, - &c_CustomOpStringECMARegexReplace, - &c_CustomOpStringECMARegexSplitWithOffsets, - &c_CustomOpStringSplit, - &c_CustomOpStringToVector, - &c_CustomOpStringUpper, - &c_CustomOpVectorToString, - &c_CustomOpStringLength, - &c_CustomOpStringConcat, -#endif - -#ifdef ENABLE_RE2_REGEX - &c_CustomOpStringRegexReplace, - &c_CustomOpStringRegexSplitWithOffsets, -#endif - - nullptr }; - -#ifdef ENABLE_MATH -extern FxLoadCustomOpFactory LoadCustomOpClasses_Math; -#endif // ENABLE_MATH - -#ifdef ENABLE_TOKENIZER -extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer; -#endif // ENABLE_TOKENIZER +#include "ocos.h" class ExternalCustomOps { @@ -154,7 +72,10 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, #endif static std::vector c_factories = { - []() { return const_cast(operator_lists); } + LoadCustomOpClasses +#if defined(ENABLE_TF_STRING) + , LoadCustomOpClasses_Text +#endif // ENABLE_TF_STRING #if defined(ENABLE_MATH) , LoadCustomOpClasses_Math #endif diff --git a/test/static_test/test_strings.cc b/test/static_test/test_strings.cc index 3ac72225..9805f7dc 100644 --- a/test/static_test/test_strings.cc +++ b/test/static_test/test_strings.cc @@ -3,7 +3,7 @@ #include "gtest/gtest.h" #include "string_utils.h" -#include "text/string_regex_split_re.hpp" +#include "text/re2_strings/string_regex_split_re.hpp" #include "text/string_ecmaregex_split.hpp"