support the non-exception compiling for the text domain. (#142)
* support the non-exception compiling for the text domain. * fix an path error.
This commit is contained in:
Родитель
97ec950751
Коммит
2842d2208e
|
@ -5,6 +5,7 @@ build_host_protoc
|
||||||
build_android
|
build_android
|
||||||
build_ios
|
build_ios
|
||||||
build_*
|
build_*
|
||||||
|
_subbuild/
|
||||||
.build_debug/*
|
.build_debug/*
|
||||||
.build_release/*
|
.build_release/*
|
||||||
distribute/*
|
distribute/*
|
||||||
|
|
|
@ -38,7 +38,9 @@ option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file"
|
||||||
|
|
||||||
|
|
||||||
function(disable_all_operators)
|
function(disable_all_operators)
|
||||||
|
set(OCOS_ENABLE_RE2_REGEX OFF CACHE INTERNAL "")
|
||||||
set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "")
|
set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "")
|
||||||
|
set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "")
|
||||||
set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "")
|
set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "")
|
||||||
set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
|
set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
|
||||||
set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "")
|
set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "")
|
||||||
|
@ -122,6 +124,11 @@ if (OCOS_ENABLE_TF_STRING)
|
||||||
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (OCOS_ENABLE_RE2_REGEX)
|
||||||
|
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
|
||||||
|
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
|
||||||
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_MATH)
|
if (OCOS_ENABLE_MATH)
|
||||||
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
|
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
|
||||||
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
||||||
|
|
|
@ -59,9 +59,6 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct CustomOpClassBegin{
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
class CuopContainer {
|
class CuopContainer {
|
||||||
public:
|
public:
|
||||||
|
@ -86,6 +83,9 @@ class CuopContainer {
|
||||||
std::vector<OrtCustomOp*> ocos_list_;
|
std::vector<OrtCustomOp*> ocos_list_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct CustomOpClassBegin{
|
||||||
|
};
|
||||||
|
|
||||||
typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
|
typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
|
||||||
|
|
||||||
template <typename _Begin_place_holder, typename... Args>
|
template <typename _Begin_place_holder, typename... Args>
|
||||||
|
@ -99,3 +99,15 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
||||||
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
|
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
|
||||||
bool EnablePyCustomOps(bool enable = true);
|
bool EnablePyCustomOps(bool enable = true);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef ENABLE_MATH
|
||||||
|
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
|
||||||
|
#endif // ENABLE_MATH
|
||||||
|
|
||||||
|
#ifdef ENABLE_TOKENIZER
|
||||||
|
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
|
||||||
|
#endif // ENABLE_TOKENIZER
|
||||||
|
|
||||||
|
#ifdef ENABLE_TF_STRING
|
||||||
|
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
|
||||||
|
#endif // ENABLE_TF_STRING
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
#include "negpos.hpp"
|
#include "negpos.hpp"
|
||||||
#include "inverse.hpp"
|
#include "inverse.hpp"
|
||||||
|
#include "segment_extraction.hpp"
|
||||||
|
#include "segment_sum.hpp"
|
||||||
|
|
||||||
|
|
||||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = LoadCustomOpClasses<CustomOpClassBegin, CustomOpNegPos, CustomOpInverse>;
|
FxLoadCustomOpFactory LoadCustomOpClasses_Math =
|
||||||
|
LoadCustomOpClasses<CustomOpClassBegin,
|
||||||
|
CustomOpNegPos,
|
||||||
|
CustomOpInverse,
|
||||||
|
CustomOpSegmentExtraction,
|
||||||
|
CustomOpSegmentSum>;
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "op_segment_extraction.hpp"
|
#include "segment_extraction.hpp"
|
||||||
|
|
||||||
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
|
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
|
||||||
}
|
}
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelSegmentExtraction : BaseKernel {
|
struct KernelSegmentExtraction : BaseKernel {
|
|
@ -1,7 +1,7 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "op_segment_sum.hpp"
|
#include "segment_sum.hpp"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelSegmentSum : BaseKernel {
|
struct KernelSegmentSum : BaseKernel {
|
|
@ -1,9 +0,0 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ocos.h"
|
|
||||||
|
|
||||||
|
|
||||||
// TO BE DELETED.
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringEqual : BaseKernel {
|
struct KernelStringEqual : BaseKernel {
|
||||||
|
|
|
@ -13,7 +13,7 @@ class BroadcastIteratorRight {
|
||||||
const std::vector<int64_t>& shape2,
|
const std::vector<int64_t>& shape2,
|
||||||
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
|
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
|
||||||
if (shape2.size() > shape1.size())
|
if (shape2.size() > shape1.size())
|
||||||
throw std::runtime_error("shape2 must have less dimensions than shape1");
|
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
|
||||||
shape2_.resize(shape1_.size());
|
shape2_.resize(shape1_.size());
|
||||||
cum_shape2_.resize(shape1_.size());
|
cum_shape2_.resize(shape1_.size());
|
||||||
total_ = 1;
|
total_ = 1;
|
||||||
|
@ -26,8 +26,8 @@ class BroadcastIteratorRight {
|
||||||
shape2_[i] = shape2[i];
|
shape2_[i] = shape2[i];
|
||||||
}
|
}
|
||||||
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]));
|
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cum_shape2_[shape2_.size() - 1] = 1;
|
cum_shape2_[shape2_.size() - 1] = 1;
|
||||||
|
@ -84,7 +84,7 @@ class BroadcastIteratorRight {
|
||||||
template <typename TCMP>
|
template <typename TCMP>
|
||||||
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
|
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
|
||||||
if (pos != 0)
|
if (pos != 0)
|
||||||
throw std::runtime_error("Not implemented yet.");
|
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
|
||||||
while (!end()) {
|
while (!end()) {
|
||||||
*p3 = cmp(*p1, *p2);
|
*p3 = cmp(*p1, *p2);
|
||||||
next();
|
next();
|
||||||
|
|
|
@ -14,8 +14,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions d_length(ort_, n_elements);
|
OrtTensorDimensions d_length(ort_, n_elements);
|
||||||
|
|
||||||
if (d_length.size() != 1)
|
if (d_length.size() != 1)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"First input must have one dimension not ", d_length.size(), "."));
|
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
|
||||||
int64_t n_els = d_length[0] - 1;
|
int64_t n_els = d_length[0] - 1;
|
||||||
int64_t n_values = p_n_elements[n_els];
|
int64_t n_values = p_n_elements[n_els];
|
||||||
std::vector<int64_t> shape{n_values, 2};
|
std::vector<int64_t> shape{n_values, 2};
|
||||||
|
@ -110,9 +110,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||||
for (int64_t i = 0; i < size - 1; ++i) {
|
for (int64_t i = 0; i < size - 1; ++i) {
|
||||||
pos_end = pos + max_col;
|
pos_end = pos + max_col;
|
||||||
if (pos_end > shape_out_size)
|
if (pos_end > shape_out_size)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||||
" - i=", i, " size=", size, "."));
|
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||||
dense[pos] = p_values[j];
|
dense[pos] = p_values[j];
|
||||||
}
|
}
|
||||||
|
@ -168,9 +168,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||||
for (int64_t i = 0; i < size - 1; ++i) {
|
for (int64_t i = 0; i < size - 1; ++i) {
|
||||||
pos_end = pos + max_col;
|
pos_end = pos + max_col;
|
||||||
if (pos_end > shape_out_size)
|
if (pos_end > shape_out_size)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||||
" - i=", i, " size=", size, "."));
|
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||||
dense[pos] = input[j];
|
dense[pos] = input[j];
|
||||||
}
|
}
|
||||||
|
@ -210,6 +210,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."));
|
ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||||
KernelRaggedTensorToSparse(OrtApi api);
|
KernelRaggedTensorToSparse(OrtApi api);
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringRegexReplace : BaseKernel {
|
struct KernelStringRegexReplace : BaseKernel {
|
|
@ -3,9 +3,9 @@
|
||||||
|
|
||||||
#include "string_regex_split.hpp"
|
#include "string_regex_split.hpp"
|
||||||
#include "string_regex_split_re.hpp"
|
#include "string_regex_split_re.hpp"
|
||||||
|
#include "string_tensor.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include "string_tensor.h"
|
|
||||||
|
|
||||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
}
|
}
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
|
@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions right_dim(ort_, right);
|
OrtTensorDimensions right_dim(ort_, right);
|
||||||
|
|
||||||
if (left_dim != right_dim) {
|
if (left_dim != right_dim) {
|
||||||
throw std::runtime_error("Two input tensor should have the same dimension.");
|
ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> left_value;
|
std::vector<std::string> left_value;
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringConcat : BaseKernel {
|
struct KernelStringConcat : BaseKernel {
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringECMARegexReplace : BaseKernel {
|
struct KernelStringECMARegexReplace : BaseKernel {
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||||
|
|
|
@ -23,9 +23,9 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"num_buckets must contain only one element. It has ",
|
"num_buckets must contain only one element. It has ",
|
||||||
num_buckets_dimensions.size(), " dimensions."));
|
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
|
@ -60,7 +60,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -86,9 +86,9 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"num_buckets must contain only one element. It has ",
|
"num_buckets must contain only one element. It has ",
|
||||||
num_buckets_dimensions.size(), " dimensions."));
|
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
|
@ -123,7 +123,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringHash : BaseKernel {
|
struct KernelStringHash : BaseKernel {
|
||||||
|
|
|
@ -20,13 +20,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||||
throw std::runtime_error("Input 2 is the separator, it has 1 element.");
|
ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
||||||
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
||||||
throw std::runtime_error("Input 3 is the axis, it has 1 element.");
|
ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions(ort_, input_X);
|
OrtTensorDimensions dimensions(ort_, input_X);
|
||||||
if (*axis < 0 || *axis >= dimensions.size())
|
if (*axis < 0 || *axis >= dimensions.size())
|
||||||
throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis));
|
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
||||||
if (dimensions.size() > 1) {
|
if (dimensions.size() > 1) {
|
||||||
|
@ -89,7 +89,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringJoin : BaseKernel {
|
struct KernelStringJoin : BaseKernel {
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringLength : BaseKernel {
|
struct KernelStringLength : BaseKernel {
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringLower : BaseKernel {
|
struct KernelStringLower : BaseKernel {
|
||||||
|
|
|
@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||||
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
|
ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
||||||
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
||||||
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
|
ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions(ort_, input_X);
|
OrtTensorDimensions dimensions(ort_, input_X);
|
||||||
if (dimensions.size() != 1)
|
if (dimensions.size() != 1)
|
||||||
throw std::runtime_error("Only 1D tensor are supported as input.");
|
ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
std::vector<int64_t> indices;
|
std::vector<int64_t> indices;
|
||||||
|
@ -116,7 +116,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -132,6 +132,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
|
ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "ocos.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringSplit : BaseKernel {
|
struct KernelStringSplit : BaseKernel {
|
||||||
|
|
|
@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
|
|
||||||
vector_len_ = ParseVectorLen(lines[0]);
|
vector_len_ = ParseVectorLen(lines[0]);
|
||||||
if (vector_len_ == 0) {
|
if (vector_len_ == 0) {
|
||||||
throw std::runtime_error(MakeString("The mapped value of string input cannot be empty: ", lines[0]));
|
ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> values(vector_len_);
|
std::vector<int64_t> values(vector_len_);
|
||||||
|
@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseValues(kv[1], values);
|
ParseValues(kv[1], values);
|
||||||
|
@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
||||||
auto unk_strs = SplitString(unk, " ", true);
|
auto unk_strs = SplitString(unk, " ", true);
|
||||||
if (unk_strs.size() != vector_len_) {
|
if (unk_strs.size() != vector_len_) {
|
||||||
throw std::runtime_error(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_));
|
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& str : unk_strs) {
|
for (auto& str : unk_strs) {
|
||||||
int64_t value;
|
int64_t value;
|
||||||
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
||||||
if (end != str.data() + str.size()) {
|
if (end != str.data() + str.size()) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse unknown_value when processing the number: ", str));
|
ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
unk_value_.push_back(value);
|
unk_value_.push_back(value);
|
||||||
|
@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value_strs = SplitString(kv[1], " ", true);
|
auto value_strs = SplitString(kv[1], " ", true);
|
||||||
|
@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
for (int i = 0; i < value_strs.size(); i++) {
|
for (int i = 0; i < value_strs.size(); i++) {
|
||||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
|
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
values[i] = value;
|
values[i] = value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
#include "text/op_equal.hpp"
|
||||||
|
#include "text/op_ragged_tensor.hpp"
|
||||||
|
#include "text/string_hash.hpp"
|
||||||
|
#include "text/string_join.hpp"
|
||||||
|
#include "text/string_lower.hpp"
|
||||||
|
#include "text/string_split.hpp"
|
||||||
|
#include "text/string_to_vector.hpp"
|
||||||
|
#include "text/string_upper.hpp"
|
||||||
|
#include "text/vector_to_string.hpp"
|
||||||
|
#include "text/string_length.hpp"
|
||||||
|
#include "text/string_concat.hpp"
|
||||||
|
#include "text/string_ecmaregex_replace.hpp"
|
||||||
|
#include "text/string_ecmaregex_split.hpp"
|
||||||
|
|
||||||
|
#if defined(ENABLE_RE2_REGEX)
|
||||||
|
#include "text/re2_strings/string_regex_replace.hpp"
|
||||||
|
#include "text/re2_strings/string_regex_split.hpp"
|
||||||
|
#endif // ENABLE_RE2_REGEX
|
||||||
|
|
||||||
|
|
||||||
|
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
|
||||||
|
LoadCustomOpClasses<CustomOpClassBegin,
|
||||||
|
#if defined(ENABLE_RE2_REGEX)
|
||||||
|
CustomOpStringRegexReplace,
|
||||||
|
CustomOpStringRegexSplitWithOffsets,
|
||||||
|
#endif // ENABLE_RE2_REGEX
|
||||||
|
CustomOpRaggedTensorToDense,
|
||||||
|
CustomOpRaggedTensorToSparse,
|
||||||
|
CustomOpStringRaggedTensorToDense,
|
||||||
|
CustomOpStringEqual,
|
||||||
|
CustomOpStringHash,
|
||||||
|
CustomOpStringHashFast,
|
||||||
|
CustomOpStringJoin,
|
||||||
|
CustomOpStringLower,
|
||||||
|
CustomOpStringUpper,
|
||||||
|
CustomOpStringSplit,
|
||||||
|
CustomOpStringToVector,
|
||||||
|
CustomOpVectorToString,
|
||||||
|
CustomOpStringLength,
|
||||||
|
CustomOpStringConcat,
|
||||||
|
CustomOpStringECMARegexReplace,
|
||||||
|
CustomOpStringECMARegexSplitWithOffsets
|
||||||
|
>;
|
|
@ -29,7 +29,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
||||||
output_dim = input_dim;
|
output_dim = input_dim;
|
||||||
} else {
|
} else {
|
||||||
if (input_dim[input_dim.size() - 1] != vector_len_) {
|
if (input_dim[input_dim.size() - 1] != vector_len_) {
|
||||||
throw std::runtime_error(MakeString("Incompatible dimension: required vector length should be ", vector_len_));
|
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
output_dim = input_dim;
|
output_dim = input_dim;
|
||||||
|
@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseValues(kv[1], values);
|
ParseValues(kv[1], values);
|
||||||
|
@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse mapping_table when processing the line: ", line));
|
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value_strs = SplitString(kv[1], " ", true);
|
auto value_strs = SplitString(kv[1], " ", true);
|
||||||
|
@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
for (int i = 0; i < value_strs.size(); i++) {
|
for (int i = 0; i < value_strs.size(); i++) {
|
||||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||||
throw std::runtime_error(MakeString("Failed to parse map when processing the number: ", value_strs[i]));
|
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
values[i] = value;
|
values[i] = value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,13 +11,13 @@
|
||||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||||
if (model_data_.empty()) {
|
if (model_data_.empty()) {
|
||||||
throw std::runtime_error("vocabulary shouldn't be empty.");
|
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
|
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
|
||||||
|
|
||||||
if (model_ptr == nullptr) {
|
if (model_ptr == nullptr) {
|
||||||
throw std::runtime_error("Invalid model");
|
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||||
|
@ -33,7 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
|
|
||||||
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
||||||
throw std::runtime_error("We only support string scalar.");
|
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> input_data;
|
std::vector<std::string> input_data;
|
||||||
|
@ -46,7 +46,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
|
|
||||||
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
|
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
|
||||||
if (output_length < 0) {
|
if (output_length < 0) {
|
||||||
throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed"));
|
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
// inline split output_str by newline '\n'
|
// inline split output_str by newline '\n'
|
||||||
|
|
|
@ -72,8 +72,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
||||||
rows.push_back(indices.size());
|
rows.push_back(indices.size());
|
||||||
} else if (text_index == existing_rows[row_index]) {
|
} else if (text_index == existing_rows[row_index]) {
|
||||||
if (row_index >= n_existing_rows)
|
if (row_index >= n_existing_rows)
|
||||||
throw std::runtime_error(MakeString(
|
ORT_CXX_API_THROW(MakeString(
|
||||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."));
|
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
|
||||||
rows.push_back(indices.size());
|
rows.push_back(indices.size());
|
||||||
++row_index;
|
++row_index;
|
||||||
}
|
}
|
||||||
|
@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index)
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
|
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
|
@ -4,89 +4,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "onnxruntime_extensions.h"
|
#include "onnxruntime_extensions.h"
|
||||||
#include "string_utils.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
#include "text/op_equal.hpp"
|
|
||||||
#include "text/op_segment_sum.hpp"
|
|
||||||
#include "text/op_segment_extraction.hpp"
|
|
||||||
#include "text/op_ragged_tensor.hpp"
|
|
||||||
#include "text/string_hash.hpp"
|
|
||||||
#include "text/string_join.hpp"
|
|
||||||
#include "text/string_lower.hpp"
|
|
||||||
#include "text/string_regex_replace.hpp"
|
|
||||||
#include "text/string_regex_split.hpp"
|
|
||||||
#include "text/string_split.hpp"
|
|
||||||
#include "text/string_to_vector.hpp"
|
|
||||||
#include "text/string_upper.hpp"
|
|
||||||
#include "text/vector_to_string.hpp"
|
|
||||||
#include "text/string_length.hpp"
|
|
||||||
#include "text/string_concat.hpp"
|
|
||||||
#include "text/string_ecmaregex_replace.hpp"
|
|
||||||
#include "text/string_ecmaregex_split.hpp"
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef ENABLE_TF_STRING
|
|
||||||
CustomOpSegmentExtraction c_CustomOpSegmentExtraction;
|
|
||||||
CustomOpSegmentSum c_CustomOpSegmentSum;
|
|
||||||
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
|
|
||||||
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
|
|
||||||
CustomOpStringEqual c_CustomOpStringEqual;
|
|
||||||
CustomOpStringHash c_CustomOpStringHash;
|
|
||||||
CustomOpStringHashFast c_CustomOpStringHashFast;
|
|
||||||
CustomOpStringJoin c_CustomOpStringJoin;
|
|
||||||
CustomOpStringLower c_CustomOpStringLower;
|
|
||||||
CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense;
|
|
||||||
CustomOpStringECMARegexReplace c_CustomOpStringECMARegexReplace;
|
|
||||||
CustomOpStringECMARegexSplitWithOffsets c_CustomOpStringECMARegexSplitWithOffsets;
|
|
||||||
CustomOpStringSplit c_CustomOpStringSplit;
|
|
||||||
CustomOpStringToVector c_CustomOpStringToVector;
|
|
||||||
CustomOpStringUpper c_CustomOpStringUpper;
|
|
||||||
CustomOpVectorToString c_CustomOpVectorToString;
|
|
||||||
CustomOpStringLength c_CustomOpStringLength;
|
|
||||||
CustomOpStringConcat c_CustomOpStringConcat;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef ENABLE_RE2_REGEX
|
|
||||||
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
|
||||||
CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
OrtCustomOp* operator_lists[] = {
|
|
||||||
#ifdef ENABLE_TF_STRING
|
|
||||||
&c_CustomOpRaggedTensorToDense,
|
|
||||||
&c_CustomOpRaggedTensorToSparse,
|
|
||||||
&c_CustomOpSegmentSum,
|
|
||||||
&c_CustomOpSegmentExtraction,
|
|
||||||
&c_CustomOpStringEqual,
|
|
||||||
&c_CustomOpStringHash,
|
|
||||||
&c_CustomOpStringHashFast,
|
|
||||||
&c_CustomOpStringJoin,
|
|
||||||
&c_CustomOpStringLower,
|
|
||||||
&c_CustomOpStringRaggedTensorToDense,
|
|
||||||
&c_CustomOpStringECMARegexReplace,
|
|
||||||
&c_CustomOpStringECMARegexSplitWithOffsets,
|
|
||||||
&c_CustomOpStringSplit,
|
|
||||||
&c_CustomOpStringToVector,
|
|
||||||
&c_CustomOpStringUpper,
|
|
||||||
&c_CustomOpVectorToString,
|
|
||||||
&c_CustomOpStringLength,
|
|
||||||
&c_CustomOpStringConcat,
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef ENABLE_RE2_REGEX
|
|
||||||
&c_CustomOpStringRegexReplace,
|
|
||||||
&c_CustomOpStringRegexSplitWithOffsets,
|
|
||||||
#endif
|
|
||||||
|
|
||||||
nullptr };
|
|
||||||
|
|
||||||
#ifdef ENABLE_MATH
|
|
||||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
|
|
||||||
#endif // ENABLE_MATH
|
|
||||||
|
|
||||||
#ifdef ENABLE_TOKENIZER
|
|
||||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
|
|
||||||
#endif // ENABLE_TOKENIZER
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalCustomOps {
|
class ExternalCustomOps {
|
||||||
|
@ -154,7 +72,10 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static std::vector<FxLoadCustomOpFactory> c_factories = {
|
static std::vector<FxLoadCustomOpFactory> c_factories = {
|
||||||
[]() { return const_cast<const OrtCustomOp**>(operator_lists); }
|
LoadCustomOpClasses<CustomOpClassBegin>
|
||||||
|
#if defined(ENABLE_TF_STRING)
|
||||||
|
, LoadCustomOpClasses_Text
|
||||||
|
#endif // ENABLE_TF_STRING
|
||||||
#if defined(ENABLE_MATH)
|
#if defined(ENABLE_MATH)
|
||||||
, LoadCustomOpClasses_Math
|
, LoadCustomOpClasses_Math
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
#include "text/string_regex_split_re.hpp"
|
#include "text/re2_strings/string_regex_split_re.hpp"
|
||||||
#include "text/string_ecmaregex_split.hpp"
|
#include "text/string_ecmaregex_split.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче