support the non-exception compiling for the text domain. (#142)
* support the non-exception compiling for the text domain. * fix an path error.
This commit is contained in:
Родитель
97ec950751
Коммит
2842d2208e
|
@ -5,6 +5,7 @@ build_host_protoc
|
|||
build_android
|
||||
build_ios
|
||||
build_*
|
||||
_subbuild/
|
||||
.build_debug/*
|
||||
.build_release/*
|
||||
distribute/*
|
||||
|
|
|
@ -38,7 +38,9 @@ option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file"
|
|||
|
||||
|
||||
function(disable_all_operators)
|
||||
set(OCOS_ENABLE_RE2_REGEX OFF CACHE INTERNAL "")
|
||||
set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "")
|
||||
set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "")
|
||||
set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "")
|
||||
set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
|
||||
set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "")
|
||||
|
@ -122,6 +124,11 @@ if (OCOS_ENABLE_TF_STRING)
|
|||
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_RE2_REGEX)
|
||||
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_MATH)
|
||||
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
|
||||
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
||||
|
|
|
@ -59,9 +59,6 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
};
|
||||
|
||||
|
||||
struct CustomOpClassBegin{
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
class CuopContainer {
|
||||
public:
|
||||
|
@ -86,6 +83,9 @@ class CuopContainer {
|
|||
std::vector<OrtCustomOp*> ocos_list_;
|
||||
};
|
||||
|
||||
struct CustomOpClassBegin{
|
||||
};
|
||||
|
||||
typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
|
||||
|
||||
template <typename _Begin_place_holder, typename... Args>
|
||||
|
@ -99,3 +99,15 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
|||
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
|
||||
bool EnablePyCustomOps(bool enable = true);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MATH
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
|
||||
#endif // ENABLE_MATH
|
||||
|
||||
#ifdef ENABLE_TOKENIZER
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
|
||||
#endif // ENABLE_TOKENIZER
|
||||
|
||||
#ifdef ENABLE_TF_STRING
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
|
||||
#endif // ENABLE_TF_STRING
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
#include "ocos.h"
|
||||
#include "negpos.hpp"
|
||||
#include "inverse.hpp"
|
||||
#include "segment_extraction.hpp"
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = LoadCustomOpClasses<CustomOpClassBegin, CustomOpNegPos, CustomOpInverse>;
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
CustomOpNegPos,
|
||||
CustomOpInverse,
|
||||
CustomOpSegmentExtraction,
|
||||
CustomOpSegmentSum>;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "op_segment_extraction.hpp"
|
||||
#include "segment_extraction.hpp"
|
||||
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
|
||||
}
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentExtraction : BaseKernel {
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "op_segment_sum.hpp"
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
template <typename T>
|
||||
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentSum : BaseKernel {
|
|
@ -1,9 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
// TO BE DELETED.
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringEqual : BaseKernel {
|
||||
|
|
|
@ -13,7 +13,7 @@ class BroadcastIteratorRight {
|
|||
const std::vector<int64_t>& shape2,
|
||||
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
|
||||
if (shape2.size() > shape1.size())
|
||||
throw std::runtime_error("shape2 must have less dimensions than shape1");
|
||||
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
|
||||
shape2_.resize(shape1_.size());
|
||||
cum_shape2_.resize(shape1_.size());
|
||||
total_ = 1;
|
||||
|
@ -26,8 +26,8 @@ class BroadcastIteratorRight {
|
|||
shape2_[i] = shape2[i];
|
||||
}
|
||||
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
||||
throw std::runtime_error(MakeString(
|
||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]));
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
cum_shape2_[shape2_.size() - 1] = 1;
|
||||
|
@ -84,7 +84,7 @@ class BroadcastIteratorRight {
|
|||
template <typename TCMP>
|
||||
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
|
||||
if (pos != 0)
|
||||
throw std::runtime_error("Not implemented yet.");
|
||||
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
|
||||
while (!end()) {
|
||||
*p3 = cmp(*p1, *p2);
|
||||
next();
|
||||
|
|
|
@ -14,8 +14,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions d_length(ort_, n_elements);
|
||||
|
||||
if (d_length.size() != 1)
|
||||
throw std::runtime_error(MakeString(
|
||||
"First input must have one dimension not ", d_length.size(), "."));
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
|
||||
int64_t n_els = d_length[0] - 1;
|
||||
int64_t n_values = p_n_elements[n_els];
|
||||
std::vector<int64_t> shape{n_values, 2};
|
||||
|
@ -110,9 +110,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
for (int64_t i = 0; i < size - 1; ++i) {
|
||||
pos_end = pos + max_col;
|
||||
if (pos_end > shape_out_size)
|
||||
throw std::runtime_error(MakeString(
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."));
|
||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||
dense[pos] = p_values[j];
|
||||
}
|
||||
|
@ -168,9 +168,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
for (int64_t i = 0; i < size - 1; ++i) {
|
||||
pos_end = pos + max_col;
|
||||
if (pos_end > shape_out_size)
|
||||
throw std::runtime_error(MakeString(
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."));
|
||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||
dense[pos] = input[j];
|
||||
}
|
||||
|
@ -210,6 +210,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t
|
|||
case 3:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."));
|
||||
ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||
KernelRaggedTensorToSparse(OrtApi api);
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringRegexReplace : BaseKernel {
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
#include "string_regex_split.hpp"
|
||||
#include "string_regex_split_re.hpp"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
}
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
|
@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions right_dim(ort_, right);
|
||||
|
||||
if (left_dim != right_dim) {
|
||||
throw std::runtime_error("Two input tensor should have the same dimension.");
|
||||
ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<std::string> left_value;
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringConcat : BaseKernel {
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringECMARegexReplace : BaseKernel {
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
|
|
|
@ -23,9 +23,9 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
|||
// Verifications
|
||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||
throw std::runtime_error(MakeString(
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"num_buckets must contain only one element. It has ",
|
||||
num_buckets_dimensions.size(), " dimensions."));
|
||||
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
|
@ -60,7 +60,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
|
|||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -86,9 +86,9 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
|||
// Verifications
|
||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||
throw std::runtime_error(MakeString(
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"num_buckets must contain only one element. It has ",
|
||||
num_buckets_dimensions.size(), " dimensions."));
|
||||
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
|
@ -123,7 +123,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con
|
|||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringHash : BaseKernel {
|
||||
|
|
|
@ -20,13 +20,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
throw std::runtime_error("Input 2 is the separator, it has 1 element.");
|
||||
ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
||||
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
||||
throw std::runtime_error("Input 3 is the axis, it has 1 element.");
|
||||
ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (*axis < 0 || *axis >= dimensions.size())
|
||||
throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis));
|
||||
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||
|
||||
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
||||
if (dimensions.size() > 1) {
|
||||
|
@ -89,7 +89,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
|
|||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringJoin : BaseKernel {
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLength : BaseKernel {
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLower : BaseKernel {
|
||||
|
|
|
@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
|
||||
ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
||||
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
||||
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
|
||||
ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 1)
|
||||
throw std::runtime_error("Only 1D tensor are supported as input.");
|
||||
ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
|
||||
|
||||
std::vector<std::string> words;
|
||||
std::vector<int64_t> indices;
|
||||
|
@ -116,7 +116,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const
|
|||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -132,6 +132,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const
|
|||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringSplit : BaseKernel {
|
||||
|
|
|
@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
|
||||
vector_len_ = ParseVectorLen(lines[0]);
|
||||
if (vector_len_ == 0) {
|
||||
throw std::runtime_error(MakeString("The mapped value of string input cannot be empty: ", lines[0]));
|
||||
ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<int64_t> values(vector_len_);
|
||||
|
@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
ParseValues(kv[1], values);
|
||||
|
@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
||||
auto unk_strs = SplitString(unk, " ", true);
|
||||
if (unk_strs.size() != vector_len_) {
|
||||
throw std::runtime_error(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_));
|
||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
for (auto& str : unk_strs) {
|
||||
int64_t value;
|
||||
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
||||
if (end != str.data() + str.size()) {
|
||||
throw std::runtime_error(MakeString("Failed to parse unknown_value when processing the number: ", str));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
unk_value_.push_back(value);
|
||||
|
@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto value_strs = SplitString(kv[1], " ", true);
|
||||
|
@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
for (int i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
values[i] = value;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
#include "text/op_equal.hpp"
|
||||
#include "text/op_ragged_tensor.hpp"
|
||||
#include "text/string_hash.hpp"
|
||||
#include "text/string_join.hpp"
|
||||
#include "text/string_lower.hpp"
|
||||
#include "text/string_split.hpp"
|
||||
#include "text/string_to_vector.hpp"
|
||||
#include "text/string_upper.hpp"
|
||||
#include "text/vector_to_string.hpp"
|
||||
#include "text/string_length.hpp"
|
||||
#include "text/string_concat.hpp"
|
||||
#include "text/string_ecmaregex_replace.hpp"
|
||||
#include "text/string_ecmaregex_split.hpp"
|
||||
|
||||
#if defined(ENABLE_RE2_REGEX)
|
||||
#include "text/re2_strings/string_regex_replace.hpp"
|
||||
#include "text/re2_strings/string_regex_split.hpp"
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
#if defined(ENABLE_RE2_REGEX)
|
||||
CustomOpStringRegexReplace,
|
||||
CustomOpStringRegexSplitWithOffsets,
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
CustomOpRaggedTensorToDense,
|
||||
CustomOpRaggedTensorToSparse,
|
||||
CustomOpStringRaggedTensorToDense,
|
||||
CustomOpStringEqual,
|
||||
CustomOpStringHash,
|
||||
CustomOpStringHashFast,
|
||||
CustomOpStringJoin,
|
||||
CustomOpStringLower,
|
||||
CustomOpStringUpper,
|
||||
CustomOpStringSplit,
|
||||
CustomOpStringToVector,
|
||||
CustomOpVectorToString,
|
||||
CustomOpStringLength,
|
||||
CustomOpStringConcat,
|
||||
CustomOpStringECMARegexReplace,
|
||||
CustomOpStringECMARegexSplitWithOffsets
|
||||
>;
|
|
@ -29,7 +29,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
|||
output_dim = input_dim;
|
||||
} else {
|
||||
if (input_dim[input_dim.size() - 1] != vector_len_) {
|
||||
throw std::runtime_error(MakeString("Incompatible dimension: required vector length should be ", vector_len_));
|
||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
output_dim = input_dim;
|
||||
|
@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
ParseValues(kv[1], values);
|
||||
|
@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto value_strs = SplitString(kv[1], " ", true);
|
||||
|
@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
for (int i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
values[i] = value;
|
||||
}
|
||||
|
|
|
@ -11,13 +11,13 @@
|
|||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
if (model_data_.empty()) {
|
||||
throw std::runtime_error("vocabulary shouldn't be empty.");
|
||||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
|
||||
|
||||
if (model_ptr == nullptr) {
|
||||
throw std::runtime_error("Invalid model");
|
||||
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||
|
@ -33,7 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions dimensions(ort_, input);
|
||||
|
||||
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
||||
throw std::runtime_error("We only support string scalar.");
|
||||
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<std::string> input_data;
|
||||
|
@ -46,7 +46,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
|
||||
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
|
||||
if (output_length < 0) {
|
||||
throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed"));
|
||||
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
// inline split output_str by newline '\n'
|
||||
|
|
|
@ -72,8 +72,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
|||
rows.push_back(indices.size());
|
||||
} else if (text_index == existing_rows[row_index]) {
|
||||
if (row_index >= n_existing_rows)
|
||||
throw std::runtime_error(MakeString(
|
||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."));
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
|
||||
rows.push_back(indices.size());
|
||||
++row_index;
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index)
|
|||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
|
|||
case 3:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
|
||||
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
|
@ -4,89 +4,7 @@
|
|||
#include <set>
|
||||
|
||||
#include "onnxruntime_extensions.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
#include "text/op_equal.hpp"
|
||||
#include "text/op_segment_sum.hpp"
|
||||
#include "text/op_segment_extraction.hpp"
|
||||
#include "text/op_ragged_tensor.hpp"
|
||||
#include "text/string_hash.hpp"
|
||||
#include "text/string_join.hpp"
|
||||
#include "text/string_lower.hpp"
|
||||
#include "text/string_regex_replace.hpp"
|
||||
#include "text/string_regex_split.hpp"
|
||||
#include "text/string_split.hpp"
|
||||
#include "text/string_to_vector.hpp"
|
||||
#include "text/string_upper.hpp"
|
||||
#include "text/vector_to_string.hpp"
|
||||
#include "text/string_length.hpp"
|
||||
#include "text/string_concat.hpp"
|
||||
#include "text/string_ecmaregex_replace.hpp"
|
||||
#include "text/string_ecmaregex_split.hpp"
|
||||
|
||||
|
||||
#ifdef ENABLE_TF_STRING
|
||||
CustomOpSegmentExtraction c_CustomOpSegmentExtraction;
|
||||
CustomOpSegmentSum c_CustomOpSegmentSum;
|
||||
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
|
||||
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
|
||||
CustomOpStringEqual c_CustomOpStringEqual;
|
||||
CustomOpStringHash c_CustomOpStringHash;
|
||||
CustomOpStringHashFast c_CustomOpStringHashFast;
|
||||
CustomOpStringJoin c_CustomOpStringJoin;
|
||||
CustomOpStringLower c_CustomOpStringLower;
|
||||
CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense;
|
||||
CustomOpStringECMARegexReplace c_CustomOpStringECMARegexReplace;
|
||||
CustomOpStringECMARegexSplitWithOffsets c_CustomOpStringECMARegexSplitWithOffsets;
|
||||
CustomOpStringSplit c_CustomOpStringSplit;
|
||||
CustomOpStringToVector c_CustomOpStringToVector;
|
||||
CustomOpStringUpper c_CustomOpStringUpper;
|
||||
CustomOpVectorToString c_CustomOpVectorToString;
|
||||
CustomOpStringLength c_CustomOpStringLength;
|
||||
CustomOpStringConcat c_CustomOpStringConcat;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
||||
CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
|
||||
#endif
|
||||
|
||||
OrtCustomOp* operator_lists[] = {
|
||||
#ifdef ENABLE_TF_STRING
|
||||
&c_CustomOpRaggedTensorToDense,
|
||||
&c_CustomOpRaggedTensorToSparse,
|
||||
&c_CustomOpSegmentSum,
|
||||
&c_CustomOpSegmentExtraction,
|
||||
&c_CustomOpStringEqual,
|
||||
&c_CustomOpStringHash,
|
||||
&c_CustomOpStringHashFast,
|
||||
&c_CustomOpStringJoin,
|
||||
&c_CustomOpStringLower,
|
||||
&c_CustomOpStringRaggedTensorToDense,
|
||||
&c_CustomOpStringECMARegexReplace,
|
||||
&c_CustomOpStringECMARegexSplitWithOffsets,
|
||||
&c_CustomOpStringSplit,
|
||||
&c_CustomOpStringToVector,
|
||||
&c_CustomOpStringUpper,
|
||||
&c_CustomOpVectorToString,
|
||||
&c_CustomOpStringLength,
|
||||
&c_CustomOpStringConcat,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
&c_CustomOpStringRegexReplace,
|
||||
&c_CustomOpStringRegexSplitWithOffsets,
|
||||
#endif
|
||||
|
||||
nullptr };
|
||||
|
||||
#ifdef ENABLE_MATH
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
|
||||
#endif // ENABLE_MATH
|
||||
|
||||
#ifdef ENABLE_TOKENIZER
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
|
||||
#endif // ENABLE_TOKENIZER
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
class ExternalCustomOps {
|
||||
|
@ -154,7 +72,10 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
#endif
|
||||
|
||||
static std::vector<FxLoadCustomOpFactory> c_factories = {
|
||||
[]() { return const_cast<const OrtCustomOp**>(operator_lists); }
|
||||
LoadCustomOpClasses<CustomOpClassBegin>
|
||||
#if defined(ENABLE_TF_STRING)
|
||||
, LoadCustomOpClasses_Text
|
||||
#endif // ENABLE_TF_STRING
|
||||
#if defined(ENABLE_MATH)
|
||||
, LoadCustomOpClasses_Math
|
||||
#endif
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "string_utils.h"
|
||||
#include "text/string_regex_split_re.hpp"
|
||||
#include "text/re2_strings/string_regex_split_re.hpp"
|
||||
#include "text/string_ecmaregex_split.hpp"
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче