support the non-exception compiling for the text domain. (#142)

* support the non-exception compiling for the text domain.

* fix a path error.
This commit is contained in:
Wenbing Li 2021-09-02 11:19:18 -07:00 коммит произвёл GitHub
Родитель 97ec950751
Коммит 2842d2208e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
37 изменённых файлов: 142 добавлений и 160 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -5,6 +5,7 @@ build_host_protoc
build_android
build_ios
build_*
_subbuild/
.build_debug/*
.build_release/*
distribute/*

Просмотреть файл

@ -38,7 +38,9 @@ option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file"
function(disable_all_operators)
set(OCOS_ENABLE_RE2_REGEX OFF CACHE INTERNAL "")
set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "")
set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "")
set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "")
set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "")
@ -122,6 +124,11 @@ if (OCOS_ENABLE_TF_STRING)
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
endif()
if (OCOS_ENABLE_RE2_REGEX)
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
endif()
if (OCOS_ENABLE_MATH)
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")

Просмотреть файл

@ -59,9 +59,6 @@ struct OrtTensorDimensions : std::vector<int64_t> {
};
struct CustomOpClassBegin{
};
template <typename... Args>
class CuopContainer {
public:
@ -86,6 +83,9 @@ class CuopContainer {
std::vector<OrtCustomOp*> ocos_list_;
};
struct CustomOpClassBegin{
};
typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
template <typename _Begin_place_holder, typename... Args>
@ -99,3 +99,15 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
bool EnablePyCustomOps(bool enable = true);
#endif
#ifdef ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif // ENABLE_MATH
#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif // ENABLE_TOKENIZER
#ifdef ENABLE_TF_STRING
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif // ENABLE_TF_STRING

Просмотреть файл

@ -1,6 +1,13 @@
#include "ocos.h"
#include "negpos.hpp"
#include "inverse.hpp"
#include "segment_extraction.hpp"
#include "segment_sum.hpp"
FxLoadCustomOpFactory LoadCustomOpClasses_Math = LoadCustomOpClasses<CustomOpClassBegin, CustomOpNegPos, CustomOpInverse>;
FxLoadCustomOpFactory LoadCustomOpClasses_Math =
LoadCustomOpClasses<CustomOpClassBegin,
CustomOpNegPos,
CustomOpInverse,
CustomOpSegmentExtraction,
CustomOpSegmentSum>;

Просмотреть файл

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "op_segment_extraction.hpp"
#include "segment_extraction.hpp"
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
}

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelSegmentExtraction : BaseKernel {

Просмотреть файл

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "op_segment_sum.hpp"
#include "segment_sum.hpp"
template <typename T>
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelSegmentSum : BaseKernel {

Просмотреть файл

@ -1,9 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
// TO BE DELETED.

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringEqual : BaseKernel {

Просмотреть файл

@ -13,7 +13,7 @@ class BroadcastIteratorRight {
const std::vector<int64_t>& shape2,
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
if (shape2.size() > shape1.size())
throw std::runtime_error("shape2 must have less dimensions than shape1");
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
shape2_.resize(shape1_.size());
cum_shape2_.resize(shape1_.size());
total_ = 1;
@ -26,8 +26,8 @@ class BroadcastIteratorRight {
shape2_[i] = shape2[i];
}
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
throw std::runtime_error(MakeString(
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]));
ORT_CXX_API_THROW(MakeString(
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
}
}
cum_shape2_[shape2_.size() - 1] = 1;
@ -84,7 +84,7 @@ class BroadcastIteratorRight {
template <typename TCMP>
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
if (pos != 0)
throw std::runtime_error("Not implemented yet.");
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
while (!end()) {
*p3 = cmp(*p1, *p2);
next();

Просмотреть файл

@ -14,8 +14,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
OrtTensorDimensions d_length(ort_, n_elements);
if (d_length.size() != 1)
throw std::runtime_error(MakeString(
"First input must have one dimension not ", d_length.size(), "."));
ORT_CXX_API_THROW(MakeString(
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
int64_t n_els = d_length[0] - 1;
int64_t n_values = p_n_elements[n_els];
std::vector<int64_t> shape{n_values, 2};
@ -110,9 +110,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
for (int64_t i = 0; i < size - 1; ++i) {
pos_end = pos + max_col;
if (pos_end > shape_out_size)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."));
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
dense[pos] = p_values[j];
}
@ -168,9 +168,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
for (int64_t i = 0; i < size - 1; ++i) {
pos_end = pos + max_col;
if (pos_end > shape_out_size)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."));
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
dense[pos] = input[j];
}
@ -210,6 +210,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."));
ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(OrtApi api);

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringRegexReplace : BaseKernel {

Просмотреть файл

@ -3,9 +3,9 @@
#include "string_regex_split.hpp"
#include "string_regex_split_re.hpp"
#include "string_tensor.h"
#include <vector>
#include <cmath>
#include "string_tensor.h"
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
}

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.

Просмотреть файл

@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
OrtTensorDimensions right_dim(ort_, right);
if (left_dim != right_dim) {
throw std::runtime_error("Two input tensor should have the same dimension.");
ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
}
std::vector<std::string> left_value;

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringConcat : BaseKernel {

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringECMARegexReplace : BaseKernel {

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include <regex>
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.

Просмотреть файл

@ -23,9 +23,9 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."));
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
// Setup output
OrtTensorDimensions dimensions(ort_, input);
@ -60,7 +60,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
@ -86,9 +86,9 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."));
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
// Setup output
OrtTensorDimensions dimensions(ort_, input);
@ -123,7 +123,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringHash : BaseKernel {

Просмотреть файл

@ -20,13 +20,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
throw std::runtime_error("Input 2 is the separator, it has 1 element.");
ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
throw std::runtime_error("Input 3 is the axis, it has 1 element.");
ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
if (*axis < 0 || *axis >= dimensions.size())
throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis));
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
if (dimensions.size() > 1) {
@ -89,7 +89,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringJoin : BaseKernel {

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringLength : BaseKernel {

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringLower : BaseKernel {

Просмотреть файл

@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 1)
throw std::runtime_error("Only 1D tensor are supported as input.");
ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
std::vector<std::string> words;
std::vector<int64_t> indices;
@ -116,7 +116,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
@ -132,6 +132,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
default:
throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
#include "string_utils.h"
struct KernelStringSplit : BaseKernel {

Просмотреть файл

@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
vector_len_ = ParseVectorLen(lines[0]);
if (vector_len_ == 0) {
throw std::runtime_error(MakeString("The mapped value of string input cannot be empty: ", lines[0]));
ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
}
std::vector<int64_t> values(vector_len_);
@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
}
ParseValues(kv[1], values);
@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
auto unk_strs = SplitString(unk, " ", true);
if (unk_strs.size() != vector_len_) {
throw std::runtime_error(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_));
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
}
for (auto& str : unk_strs) {
int64_t value;
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
if (end != str.data() + str.size()) {
throw std::runtime_error(MakeString("Failed to parse unknown_value when processing the number: ", str));
ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
}
unk_value_.push_back(value);
@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
}
auto value_strs = SplitString(kv[1], " ", true);
@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
for (int i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
}
values[i] = value;
}

43
operators/text/text.cc Normal file
Просмотреть файл

@ -0,0 +1,43 @@
#include "text/op_equal.hpp"
#include "text/op_ragged_tensor.hpp"
#include "text/string_hash.hpp"
#include "text/string_join.hpp"
#include "text/string_lower.hpp"
#include "text/string_split.hpp"
#include "text/string_to_vector.hpp"
#include "text/string_upper.hpp"
#include "text/vector_to_string.hpp"
#include "text/string_length.hpp"
#include "text/string_concat.hpp"
#include "text/string_ecmaregex_replace.hpp"
#include "text/string_ecmaregex_split.hpp"
#if defined(ENABLE_RE2_REGEX)
#include "text/re2_strings/string_regex_replace.hpp"
#include "text/re2_strings/string_regex_split.hpp"
#endif // ENABLE_RE2_REGEX
// Factory object for the text-domain custom operators, held as an
// FxLoadCustomOpFactory. It is bound to the LoadCustomOpClasses template
// instantiated with CustomOpClassBegin as the leading placeholder type,
// followed by every text operator class pulled in by the includes above.
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
LoadCustomOpClasses<CustomOpClassBegin,
// The RE2-backed regex operators are listed only when ENABLE_RE2_REGEX is
// defined; their trailing comma relies on the unconditional entries below.
#if defined(ENABLE_RE2_REGEX)
CustomOpStringRegexReplace,
CustomOpStringRegexSplitWithOffsets,
#endif // ENABLE_RE2_REGEX
CustomOpRaggedTensorToDense,
CustomOpRaggedTensorToSparse,
CustomOpStringRaggedTensorToDense,
CustomOpStringEqual,
CustomOpStringHash,
CustomOpStringHashFast,
CustomOpStringJoin,
CustomOpStringLower,
CustomOpStringUpper,
CustomOpStringSplit,
CustomOpStringToVector,
CustomOpVectorToString,
CustomOpStringLength,
CustomOpStringConcat,
CustomOpStringECMARegexReplace,
CustomOpStringECMARegexSplitWithOffsets
>;

Просмотреть файл

@ -29,7 +29,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
output_dim = input_dim;
} else {
if (input_dim[input_dim.size() - 1] != vector_len_) {
throw std::runtime_error(MakeString("Incompatible dimension: required vector length should be ", vector_len_));
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
}
output_dim = input_dim;
@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
}
ParseValues(kv[1], values);
@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
}
auto value_strs = SplitString(kv[1], " ", true);
@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
for (int i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
}
values[i] = value;
}

Просмотреть файл

@ -11,13 +11,13 @@
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
if (model_data_.empty()) {
throw std::runtime_error("vocabulary shouldn't be empty.");
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
if (model_ptr == nullptr) {
throw std::runtime_error("Invalid model");
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
}
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
@ -33,7 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
OrtTensorDimensions dimensions(ort_, input);
if (dimensions.Size() != 1 && dimensions[0] != 1) {
throw std::runtime_error("We only support string scalar.");
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
}
std::vector<std::string> input_data;
@ -46,7 +46,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
if (output_length < 0) {
throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed"));
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
}
// inline split output_str by newline '\n'

Просмотреть файл

@ -72,8 +72,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
rows.push_back(indices.size());
} else if (text_index == existing_rows[row_index]) {
if (row_index >= n_existing_rows)
throw std::runtime_error(MakeString(
"row_index=", row_index, " is out of range=", n_existing_rows, "."));
ORT_CXX_API_THROW(MakeString(
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
rows.push_back(indices.size());
++row_index;
}
@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index)
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};

Просмотреть файл

@ -4,89 +4,7 @@
#include <set>
#include "onnxruntime_extensions.h"
#include "string_utils.h"
#include "text/op_equal.hpp"
#include "text/op_segment_sum.hpp"
#include "text/op_segment_extraction.hpp"
#include "text/op_ragged_tensor.hpp"
#include "text/string_hash.hpp"
#include "text/string_join.hpp"
#include "text/string_lower.hpp"
#include "text/string_regex_replace.hpp"
#include "text/string_regex_split.hpp"
#include "text/string_split.hpp"
#include "text/string_to_vector.hpp"
#include "text/string_upper.hpp"
#include "text/vector_to_string.hpp"
#include "text/string_length.hpp"
#include "text/string_concat.hpp"
#include "text/string_ecmaregex_replace.hpp"
#include "text/string_ecmaregex_split.hpp"
#ifdef ENABLE_TF_STRING
CustomOpSegmentExtraction c_CustomOpSegmentExtraction;
CustomOpSegmentSum c_CustomOpSegmentSum;
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
CustomOpStringEqual c_CustomOpStringEqual;
CustomOpStringHash c_CustomOpStringHash;
CustomOpStringHashFast c_CustomOpStringHashFast;
CustomOpStringJoin c_CustomOpStringJoin;
CustomOpStringLower c_CustomOpStringLower;
CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense;
CustomOpStringECMARegexReplace c_CustomOpStringECMARegexReplace;
CustomOpStringECMARegexSplitWithOffsets c_CustomOpStringECMARegexSplitWithOffsets;
CustomOpStringSplit c_CustomOpStringSplit;
CustomOpStringToVector c_CustomOpStringToVector;
CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
CustomOpStringLength c_CustomOpStringLength;
CustomOpStringConcat c_CustomOpStringConcat;
#endif
#ifdef ENABLE_RE2_REGEX
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
#endif
OrtCustomOp* operator_lists[] = {
#ifdef ENABLE_TF_STRING
&c_CustomOpRaggedTensorToDense,
&c_CustomOpRaggedTensorToSparse,
&c_CustomOpSegmentSum,
&c_CustomOpSegmentExtraction,
&c_CustomOpStringEqual,
&c_CustomOpStringHash,
&c_CustomOpStringHashFast,
&c_CustomOpStringJoin,
&c_CustomOpStringLower,
&c_CustomOpStringRaggedTensorToDense,
&c_CustomOpStringECMARegexReplace,
&c_CustomOpStringECMARegexSplitWithOffsets,
&c_CustomOpStringSplit,
&c_CustomOpStringToVector,
&c_CustomOpStringUpper,
&c_CustomOpVectorToString,
&c_CustomOpStringLength,
&c_CustomOpStringConcat,
#endif
#ifdef ENABLE_RE2_REGEX
&c_CustomOpStringRegexReplace,
&c_CustomOpStringRegexSplitWithOffsets,
#endif
nullptr };
#ifdef ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif // ENABLE_MATH
#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif // ENABLE_TOKENIZER
#include "ocos.h"
class ExternalCustomOps {
@ -154,7 +72,10 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
#endif
static std::vector<FxLoadCustomOpFactory> c_factories = {
[]() { return const_cast<const OrtCustomOp**>(operator_lists); }
LoadCustomOpClasses<CustomOpClassBegin>
#if defined(ENABLE_TF_STRING)
, LoadCustomOpClasses_Text
#endif // ENABLE_TF_STRING
#if defined(ENABLE_MATH)
, LoadCustomOpClasses_Math
#endif

Просмотреть файл

@ -3,7 +3,7 @@
#include "gtest/gtest.h"
#include "string_utils.h"
#include "text/string_regex_split_re.hpp"
#include "text/re2_strings/string_regex_split_re.hpp"
#include "text/string_ecmaregex_split.hpp"