Make kernel Compute method implementations const (#500)
* Nodes can be called concurrently, so Compute must be stateless; update the kernels to make Compute const. * Fix a test that uses ustring.h. It would be better not to have duplicate declarations of GetTensorMutableDataString and FillTensorDataString in both ustring.h and string_tensor.h.
This commit is contained in:
Parent
b8bac85ecd
Commit
e448676a5e
|
@ -4,7 +4,7 @@
|
|||
#include "string_utils.h"
|
||||
#include "ustring.h"
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<std::string>& output) {
|
||||
(void)context;
|
||||
OrtTensorDimensions dimensions(ort, value);
|
||||
|
@ -23,7 +23,7 @@ void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKe
|
|||
}
|
||||
}
|
||||
|
||||
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const std::vector<std::string>& value, OrtValue* output) {
|
||||
(void)ort;
|
||||
(void)context;
|
||||
|
@ -32,11 +32,11 @@ void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelCo
|
|||
temp[i] = value[i].c_str();
|
||||
}
|
||||
|
||||
OrtW::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size()));
|
||||
OrtW::ThrowOnError(api, api.FillStringTensor(output, temp.data(), value.size()));
|
||||
}
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<ustring>& output) {
|
||||
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<ustring>& output) {
|
||||
std::vector<std::string> utf8_strings;
|
||||
GetTensorMutableDataString(api, ort, context, value, utf8_strings);
|
||||
|
||||
|
@ -46,12 +46,11 @@ void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKe
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const std::vector<ustring>& value, OrtValue* output) {
|
||||
std::vector<std::string> utf8_strings;
|
||||
utf8_strings.reserve(value.size());
|
||||
for (const auto& str: value) {
|
||||
for (const auto& str : value) {
|
||||
utf8_strings.push_back(std::string(str));
|
||||
}
|
||||
FillTensorDataString(api, ort, context, utf8_strings, output);
|
||||
|
|
|
@ -6,10 +6,9 @@
|
|||
#include "ocos.h"
|
||||
#include <string>
|
||||
|
||||
|
||||
// Retrieves a vector of strings if the input type is std::string.
|
||||
// It is a copy of the input data and can be modified to compute the output.
|
||||
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<std::string>& output);
|
||||
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const std::vector<std::string>& value, OrtValue* output);
|
||||
|
|
|
@ -45,9 +45,8 @@ struct hash<ustring> {
|
|||
};
|
||||
} // namespace std
|
||||
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<ustring>& output);
|
||||
|
||||
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
|
||||
const std::vector<ustring>& value, OrtValue* output);
|
||||
|
|
|
@ -105,9 +105,9 @@ class Tensor : public TensorBase {
|
|||
type_ = api_.GetTensorElementType(info);
|
||||
api_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
const OrtMemoryInfo* mem_info = {};
|
||||
api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info);
|
||||
api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
|
||||
if (mem_info) {
|
||||
api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_);
|
||||
api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -856,7 +856,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
|
|||
template <typename CustomOp>
|
||||
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
|
||||
template <typename... Args>
|
||||
using CustomComputeFn = void (CustomOp::*)(Args...);
|
||||
using CustomComputeFn = void (CustomOp::*)(Args...) const;
|
||||
using MyType = OrtLiteCustomStruct<CustomOp>;
|
||||
|
||||
struct Kernel {
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
|
||||
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
|
||||
|
||||
|
||||
namespace OrtW {
|
||||
|
||||
//
|
||||
|
|
|
@ -38,7 +38,7 @@ struct AudioDecoder : public BaseKernel {
|
|||
kFLAC
|
||||
};
|
||||
|
||||
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format) {
|
||||
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format) const {
|
||||
static const std::map<std::string, AudioStreamType> format_mapping = {
|
||||
{"default", AudioStreamType::kDefault},
|
||||
{"wav", AudioStreamType::kWAV},
|
||||
|
@ -98,7 +98,7 @@ struct AudioDecoder : public BaseKernel {
|
|||
|
||||
void Compute(const ortc::Tensor<uint8_t>& input,
|
||||
const std::optional<std::string> format,
|
||||
ortc::Tensor<float>& output0) {
|
||||
ortc::Tensor<float>& output0) const {
|
||||
const uint8_t* p_data = input.Data();
|
||||
auto input_dim = input.Shape();
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
|
@ -146,7 +146,7 @@ struct AudioDecoder : public BaseKernel {
|
|||
}
|
||||
|
||||
if (downsample_rate_ != 0 &&
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& inf
|
|||
binary_type_ = TryToGetAttributeWithDefault<std::string>(kBinaryType, "");
|
||||
}
|
||||
|
||||
void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) {
|
||||
void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const {
|
||||
if (inputs.Size() < 1 ||
|
||||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
|
||||
|
@ -196,7 +196,8 @@ void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::
|
|||
AzureTextInvoker::AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
|
||||
}
|
||||
|
||||
void AzureTextInvoker::Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) {
|
||||
void AzureTextInvoker::Compute(std::string_view auth, std::string_view input,
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
CurlHandler curl_handler(WriteStringCallback);
|
||||
StringBuffer string_buffer;
|
||||
|
||||
|
@ -313,8 +314,7 @@ int8_t* CreateNonStrTensor(const std::string& data_type,
|
|||
return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
|
||||
}
|
||||
|
||||
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs,
|
||||
ortc::Variadic& outputs) {
|
||||
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
|
||||
if (inputs.Size() < 1 ||
|
||||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
|
||||
|
|
|
@ -20,7 +20,7 @@ struct AzureInvoker : public BaseKernel {
|
|||
|
||||
struct AzureAudioInvoker : public AzureInvoker {
|
||||
AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output);
|
||||
void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
std::string binary_type_;
|
||||
|
@ -28,7 +28,7 @@ struct AzureAudioInvoker : public AzureInvoker {
|
|||
|
||||
struct AzureTextInvoker : public AzureInvoker {
|
||||
AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output);
|
||||
void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
std::string binary_type_;
|
||||
|
@ -36,7 +36,7 @@ struct AzureTextInvoker : public AzureInvoker {
|
|||
|
||||
struct AzureTritonInvoker : public AzureInvoker {
|
||||
AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs);
|
||||
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
|
||||
|
|
|
@ -18,7 +18,7 @@ struct STFT : public BaseKernel {
|
|||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) {
|
||||
ortc::Tensor<float>& output0) const {
|
||||
auto X = input0.Data();
|
||||
auto window = input3.data();
|
||||
auto dimensions = input0.Shape();
|
||||
|
@ -77,7 +77,7 @@ struct StftNormal : public STFT {
|
|||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) {
|
||||
ortc::Tensor<float>& output0) const {
|
||||
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -11,6 +11,6 @@ KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& inf
|
|||
void KernelStringEqual::Compute(OrtKernelContext* context,
|
||||
const ortc::Tensor<std::string>&,
|
||||
const ortc::Tensor<std::string>&,
|
||||
ortc::Tensor<bool>& output) {
|
||||
ortc::Tensor<bool>& output) const {
|
||||
KernelEqual_Compute<std::string>(api_, ort_, context);
|
||||
}
|
||||
|
|
|
@ -11,5 +11,5 @@ struct KernelStringEqual : BaseKernel {
|
|||
void Compute(OrtKernelContext* context,
|
||||
const ortc::Tensor<std::string>&,
|
||||
const ortc::Tensor<std::string>&,
|
||||
ortc::Tensor<bool>& output);
|
||||
ortc::Tensor<bool>& output) const;
|
||||
};
|
||||
|
|
|
@ -27,7 +27,8 @@ class BroadcastIteratorRight {
|
|||
}
|
||||
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
|
||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
cum_shape2_[shape2_.size() - 1] = 1;
|
||||
|
@ -114,7 +115,7 @@ inline bool Compare<std::string>::operator()(const std::string& s1, const std::s
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||
void KernelEqual_Compute(const OrtApi& api, const OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const T* X = ort_.GetTensorData<T>(input_X);
|
||||
|
@ -144,7 +145,8 @@ void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelCo
|
|||
}
|
||||
|
||||
template <>
|
||||
void KernelEqual_Compute<std::string>(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||
void KernelEqual_Compute<std::string>(const OrtApi& api, const OrtW::CustomOpApi& ort_,
|
||||
OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
#include "op_ragged_tensor.hpp"
|
||||
|
||||
void KernelRaggedTensoroSparse::Compute(const ortc::Tensor<int64_t>& n_element,
|
||||
ortc::Tensor<int64_t>& output_0,
|
||||
ortc::Tensor<int64_t>& output_1) {
|
||||
ortc::Tensor<int64_t>& output_0,
|
||||
ortc::Tensor<int64_t>& output_1) const {
|
||||
const int64_t* p_n_elements = n_element.Data();
|
||||
|
||||
auto& d_length = n_element.Shape();
|
||||
|
@ -42,14 +42,15 @@ CommonRaggedTensoroDense::CommonRaggedTensoroDense(const OrtApi& api, const OrtK
|
|||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs,
|
||||
OrtTensorDimensions* dims) const {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
inputs[i] = ort_.KernelContext_GetInput(context, i);
|
||||
dims[i] = OrtTensorDimensions(ort_, inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
|
||||
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) const {
|
||||
int64_t size = n;
|
||||
int64_t max_col = 0;
|
||||
for (int64_t i = 1; i < size; ++i) {
|
||||
|
@ -64,10 +65,10 @@ KernelRaggedTensoroDense::KernelRaggedTensoroDense(const OrtApi& api, const OrtK
|
|||
}
|
||||
|
||||
void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<int64_t>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<int64_t>& input3,
|
||||
ortc::Tensor<int64_t>& output) {
|
||||
const ortc::Tensor<int64_t>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<int64_t>& input3,
|
||||
ortc::Tensor<int64_t>& output) const {
|
||||
const int64_t* p_values = input1.Data();
|
||||
const int64_t* p_missing = input2.Data();
|
||||
const int64_t* p_indices = input3.Data();
|
||||
|
@ -97,14 +98,15 @@ void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
|
|||
}
|
||||
}
|
||||
|
||||
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensoroDense(api, info) {
|
||||
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: CommonRaggedTensoroDense(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<std::string>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<std::string>& input3,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
const ortc::Tensor<std::string>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<std::string>& input3,
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
// const OrtValue* inputs[4];
|
||||
// OrtTensorDimensions dims[4];
|
||||
|
||||
|
|
|
@ -11,15 +11,15 @@ struct KernelRaggedTensoroSparse : BaseKernel {
|
|||
|
||||
void Compute(const ortc::Tensor<int64_t>& n_element,
|
||||
ortc::Tensor<int64_t>& output_0,
|
||||
ortc::Tensor<int64_t>& output_1);
|
||||
ortc::Tensor<int64_t>& output_1) const;
|
||||
};
|
||||
|
||||
struct CommonRaggedTensoroDense : BaseKernel {
|
||||
CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
||||
protected:
|
||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
||||
int64_t GetMaxCol(int64_t n, const int64_t* p_indices);
|
||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) const;
|
||||
int64_t GetMaxCol(int64_t n, const int64_t* p_indices) const;
|
||||
};
|
||||
|
||||
struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
|
||||
|
@ -28,7 +28,7 @@ struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
|
|||
const ortc::Tensor<int64_t>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<int64_t>& input3,
|
||||
ortc::Tensor<int64_t>& output);
|
||||
ortc::Tensor<int64_t>& output) const;
|
||||
|
||||
private:
|
||||
int64_t missing_value_;
|
||||
|
@ -40,5 +40,5 @@ struct KernelStringRaggedTensoroDense : CommonRaggedTensoroDense {
|
|||
const ortc::Tensor<std::string>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<std::string>& input3,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
};
|
||||
|
|
|
@ -10,13 +10,13 @@
|
|||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
global_replace_ = TryToGetAttributeWithDefault("global_replace",1);
|
||||
global_replace_ = TryToGetAttributeWithDefault("global_replace", 1);
|
||||
}
|
||||
|
||||
void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
std::string_view str_rewrite,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
std::string_view str_pattern,
|
||||
std::string_view str_rewrite,
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
if (str_pattern.empty())
|
||||
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
|
||||
|
@ -38,4 +38,4 @@ void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
|
|||
}
|
||||
}
|
||||
output.SetStringOutput(str_input, dim);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,8 +11,8 @@ struct KernelStringRegexReplace : BaseKernel {
|
|||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
std::string_view str_rewrite,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
protected:
|
||||
int64_t global_replace_;
|
||||
};
|
||||
};
|
||||
|
|
|
@ -16,7 +16,7 @@ KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, co
|
|||
void KernelStringECMARegexReplace::Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view pattern,
|
||||
std::string_view rewrite,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
// make a copy as input is constant;
|
||||
std::vector<std::string> str_input = input.Data();
|
||||
if (pattern.empty()) {
|
||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
|
|||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view pattern,
|
||||
std::string_view rewrite,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
protected:
|
||||
bool global_replace_;
|
||||
|
|
|
@ -21,7 +21,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(const ortc::Tensor<std::stri
|
|||
ortc::Tensor<std::string>& output_text,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
ortc::Tensor<int64_t>& output3) {
|
||||
ortc::Tensor<int64_t>& output3) const {
|
||||
// Setup inputs
|
||||
auto& str_input = input.Data();
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
|||
ortc::Tensor<std::string>& output_text,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
ortc::Tensor<int64_t>& output3);
|
||||
ortc::Tensor<int64_t>& output3) const;
|
||||
|
||||
private:
|
||||
bool ignore_case_;
|
||||
|
|
|
@ -15,20 +15,23 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo&
|
|||
auto items = SplitString(line, "\t", true);
|
||||
|
||||
if (items.size() != 2) {
|
||||
ORTX_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW(
|
||||
"[StringMapping]: Should only exist two items in one line, find error in line: " + std::string(line),
|
||||
ORT_INVALID_GRAPH);
|
||||
}
|
||||
map_[std::string(items[0])] = std::string(items[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelStringMapping::Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
// make a copy as input is constant
|
||||
std::vector<std::string> input_data = input.Data();
|
||||
|
||||
for (auto& str : input_data) {
|
||||
if (map_.find(str) != map_.end()) {
|
||||
str = map_[str];
|
||||
auto entry = map_.find(str);
|
||||
if (entry != map_.end()) {
|
||||
str = entry->second;
|
||||
}
|
||||
}
|
||||
output.SetStringOutput(input_data, input.Shape());
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
struct KernelStringMapping : BaseKernel {
|
||||
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
|
|
|
@ -11,7 +11,7 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
|
|||
|
||||
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(const std::vector<std::string>& str_input,
|
||||
const std::vector<int64_t>& input_dim,
|
||||
std::vector<int64_t>& output_dim) {
|
||||
std::vector<int64_t>& output_dim) const {
|
||||
std::vector<std::vector<int64_t>> result;
|
||||
|
||||
// Set output dimension
|
||||
|
@ -42,7 +42,8 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
|
||||
vector_len_ = ParseVectorLen(lines[0]);
|
||||
if (vector_len_ == 0) {
|
||||
ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<int64_t> values(vector_len_);
|
||||
|
@ -50,7 +51,8 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
ParseValues(kv[1], values);
|
||||
|
@ -63,14 +65,17 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
|||
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
||||
auto unk_strs = SplitString(unk, " ", true);
|
||||
if (unk_strs.size() != vector_len_) {
|
||||
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(
|
||||
MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
for (auto& str : unk_strs) {
|
||||
int64_t value;
|
||||
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
||||
if (end != str.data() + str.size()) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
unk_value_.push_back(value);
|
||||
|
@ -81,7 +86,8 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto value_strs = SplitString(kv[1], " ", true);
|
||||
|
@ -95,7 +101,8 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
values[i] = value;
|
||||
}
|
||||
|
@ -110,7 +117,7 @@ KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInf
|
|||
}
|
||||
|
||||
void KernelStringToVector::Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& out) {
|
||||
ortc::Tensor<int64_t>& out) const {
|
||||
// Setup input
|
||||
auto& input_data = input.Data();
|
||||
// Get output
|
||||
|
|
|
@ -14,7 +14,7 @@ class StringToVectorImpl {
|
|||
StringToVectorImpl(std::string& map, std::string& unk);
|
||||
std::vector<std::vector<int64_t>> Compute(const std::vector<std::string>& str_input,
|
||||
const std::vector<int64_t>& input_dim,
|
||||
std::vector<int64_t>& output_dim);
|
||||
std::vector<int64_t>& output_dim) const;
|
||||
|
||||
private:
|
||||
void ParseMappingTable(std::string& map);
|
||||
|
@ -32,7 +32,7 @@ class StringToVectorImpl {
|
|||
struct KernelStringToVector : BaseKernel {
|
||||
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& out);
|
||||
ortc::Tensor<int64_t>& out) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<StringToVectorImpl> impl_;
|
||||
|
|
|
@ -20,7 +20,7 @@ VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk
|
|||
|
||||
std::vector<std::string> VectorToStringImpl::Compute(const void* input,
|
||||
const std::vector<int64_t>& input_dim,
|
||||
std::vector<int64_t>& output_dim) {
|
||||
std::vector<int64_t>& output_dim) const {
|
||||
std::vector<std::string> result;
|
||||
|
||||
const int64_t* ptr = static_cast<const int64_t*>(input);
|
||||
|
@ -30,7 +30,8 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input,
|
|||
output_dim = input_dim;
|
||||
} else {
|
||||
if (input_dim.empty() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
|
||||
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
output_dim = input_dim;
|
||||
|
@ -72,7 +73,8 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
ParseValues(kv[1], values);
|
||||
|
@ -85,7 +87,8 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
|
|||
auto kv = SplitString(line, "\t", true);
|
||||
|
||||
if (kv.size() != 2) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto value_strs = SplitString(kv[1], " ", true);
|
||||
|
@ -99,7 +102,8 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
values[i] = value;
|
||||
}
|
||||
|
@ -114,7 +118,7 @@ KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInf
|
|||
}
|
||||
|
||||
void KernelVectorToString::Compute(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<std::string>& out) {
|
||||
ortc::Tensor<std::string>& out) const {
|
||||
const void* input_data = input.Data();
|
||||
|
||||
std::vector<int64_t> output_dim;
|
||||
|
|
|
@ -22,7 +22,7 @@ class VectorToStringImpl {
|
|||
VectorToStringImpl(std::string& map, std::string& unk);
|
||||
std::vector<std::string> Compute(const void* input,
|
||||
const std::vector<int64_t>& input_dim,
|
||||
std::vector<int64_t>& output_dim);
|
||||
std::vector<int64_t>& output_dim) const;
|
||||
|
||||
private:
|
||||
void ParseMappingTable(std::string& map);
|
||||
|
@ -37,7 +37,7 @@ class VectorToStringImpl {
|
|||
struct KernelVectorToString : BaseKernel {
|
||||
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<std::string>& out);
|
||||
ortc::Tensor<std::string>& out) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<VectorToStringImpl> impl_;
|
||||
|
|
|
@ -88,11 +88,12 @@ KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInf
|
|||
bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
|
||||
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
|
||||
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
|
||||
tokenize_punctuation, remove_control_chars);
|
||||
}
|
||||
|
||||
void KernelBasicTokenizer::Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
// Setup inputs
|
||||
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
|
||||
output.SetStringOutput({result[0].operator std::string()}, {1});
|
||||
|
|
|
@ -24,7 +24,7 @@ class BasicTokenizer {
|
|||
struct KernelBasicTokenizer : BaseKernel {
|
||||
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||
|
|
|
@ -331,7 +331,7 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
ortc::Tensor<int64_t>& output,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
|
||||
// Setup inputs
|
||||
auto& input_data = input.Data();
|
||||
|
||||
|
@ -343,7 +343,7 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
std::list<OffsetMappingType> offset_map;
|
||||
|
||||
// Only compute offset mapping if optional output for it exists.
|
||||
compute_offset_mapping = false;
|
||||
bool compute_offset_mapping = false;
|
||||
if (offset_mapping.has_value()) {
|
||||
compute_offset_mapping = true;
|
||||
}
|
||||
|
@ -397,7 +397,7 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
ortc::Tensor<int64_t>& output,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
|
||||
// Setup inputs
|
||||
auto& input_data = input.Data();
|
||||
|
||||
|
@ -408,7 +408,7 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
std::list<OffsetMappingType> offset_map;
|
||||
|
||||
// Only compute offset mapping if optional output for it exists.
|
||||
compute_offset_mapping = false;
|
||||
bool compute_offset_mapping = false;
|
||||
if (offset_mapping.has_value()) {
|
||||
compute_offset_mapping = true;
|
||||
}
|
||||
|
|
|
@ -46,8 +46,10 @@ class WordpieceTokenizer final {
|
|||
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
|
||||
ustring suffix_indicator, int max_input_chars_per_word = 100);
|
||||
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
|
||||
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
|
||||
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
|
||||
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map,
|
||||
bool compute_offset_mapping);
|
||||
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map,
|
||||
bool compute_offset_mapping);
|
||||
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
|
||||
|
||||
private:
|
||||
|
@ -67,7 +69,8 @@ class BertTokenizer final {
|
|||
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
|
||||
ustring suffix_indicator, int32_t max_len, const std::string& truncation_strategy);
|
||||
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
|
||||
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
|
||||
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map,
|
||||
bool compute_offset_mapping);
|
||||
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
|
||||
|
||||
void Truncate(std::vector<int64_t>& ids);
|
||||
|
@ -98,9 +101,8 @@ struct KernelBertTokenizer : BaseKernel {
|
|||
ortc::Tensor<int64_t>& output,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
|
||||
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
|
||||
bool compute_offset_mapping;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<BertTokenizer> tokenizer_;
|
||||
|
@ -113,5 +115,5 @@ struct KernelHfBertTokenizer : KernelBertTokenizer {
|
|||
ortc::Tensor<int64_t>& output,
|
||||
ortc::Tensor<int64_t>& output1,
|
||||
ortc::Tensor<int64_t>& output2,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
|
||||
};
|
||||
|
|
|
@ -138,7 +138,7 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const
|
|||
|
||||
void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
|
||||
const ortc::Tensor<int64_t>& positions,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
const int64_t* p_ids = ids.Data();
|
||||
auto& ids_dim = ids.Shape();
|
||||
|
||||
|
@ -151,7 +151,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
|
|||
if (use_indices_ &&
|
||||
(!((positions.NumberOfElement() == 0) ||
|
||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices",
|
||||
ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
const int64_t* p_positions = positions.NumberOfElement() == 0 ? nullptr : positions.Data();
|
||||
|
@ -159,7 +160,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
|
|||
std::vector<std::string> result;
|
||||
std::vector<int64_t> output_dim(1);
|
||||
if (!use_indices_) {
|
||||
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()), skip_special_tokens_, clean_up_tokenization_spaces_));
|
||||
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()),
|
||||
skip_special_tokens_, clean_up_tokenization_spaces_));
|
||||
output_dim[0] = 1;
|
||||
} else {
|
||||
if (p_positions != nullptr) {
|
||||
|
@ -167,7 +169,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
|
|||
int64_t start = p_positions[2 * i];
|
||||
int64_t end = p_positions[2 * i + 1];
|
||||
|
||||
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
|
||||
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end),
|
||||
skip_special_tokens_, clean_up_tokenization_spaces_));
|
||||
}
|
||||
output_dim[0] = positions_dim[0];
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
|
|||
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<int64_t>& ids,
|
||||
const ortc::Tensor<int64_t>& positions,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<BertTokenizerDecoder> decoder_;
|
||||
|
|
|
@ -16,7 +16,8 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
|
||||
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()),
|
||||
static_cast<int>(model_data_.size()));
|
||||
|
||||
if (model_ptr == nullptr) {
|
||||
ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
||||
|
@ -28,11 +29,13 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
}
|
||||
|
||||
void KernelBlingFireSentenceBreaker::Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
int max_length = static_cast<int>(2 * input.size() + 1);
|
||||
std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
|
||||
|
||||
int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
|
||||
int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()),
|
||||
output_str.get(), nullptr, nullptr, max_length,
|
||||
model_.get());
|
||||
if (output_length < 0) {
|
||||
ORTX_CXX_API_THROW(MakeString("splitting input:\"", input, "\" failed"), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
|||
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output);
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
private:
|
||||
using ModelPtr = std::shared_ptr<void>;
|
||||
|
|
|
@ -100,7 +100,7 @@ struct KernelBpeDecoder : public BaseKernel {
|
|||
}
|
||||
|
||||
void Compute(const ortc::Tensor<int64_t>& ids,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
const int64_t* p_ids = ids.Data();
|
||||
const auto& ids_dim = ids.Shape();
|
||||
std::vector<int64_t> output_dim = {1};
|
||||
|
@ -181,4 +181,4 @@ struct KernelBpeDecoder : public BaseKernel {
|
|||
std::map<char32_t, unsigned char> byte_decoder_;
|
||||
std::map<int64_t, std::string> added_tokens_;
|
||||
std::set<int64_t> all_special_ids_;
|
||||
};
|
||||
};
|
||||
|
|
|
@ -31,8 +31,10 @@ KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKerne
|
|||
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
|
||||
}
|
||||
|
||||
std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map) {
|
||||
std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
|
||||
std::list<OffsetMappingType>& offset_map) const {
|
||||
std::vector<int64_t> res;
|
||||
std::list<std::pair<int, int>> byte_list;
|
||||
|
||||
WhiteSpaceClean(input);
|
||||
|
||||
|
@ -87,21 +89,21 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
|
|||
utf8_token.erase(std::remove(utf8_token.begin(), utf8_token.end(), ' '), utf8_token.end());
|
||||
|
||||
// Get byte encodings prior to performing BPE
|
||||
byte_list_.clear();
|
||||
byte_list.clear();
|
||||
for (int i = 0; i < utf8_token.length(); i++) {
|
||||
if (i == utf8_token.length() - 1) {
|
||||
std::string boundary(1, utf8_token[i]);
|
||||
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->GetEncoding(boundary + "</w>"), 1));
|
||||
byte_list.push_back(std::make_pair(bbpe_tokenizer_->GetEncoding(boundary + "</w>"), 1));
|
||||
} else {
|
||||
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])], 1));
|
||||
byte_list.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])], 1));
|
||||
}
|
||||
}
|
||||
|
||||
// Perform BPE
|
||||
bbpe_tokenizer_->bpe(byte_list_);
|
||||
bbpe_tokenizer_->bpe(byte_list);
|
||||
|
||||
// Add output to result
|
||||
for (auto p : byte_list_) {
|
||||
for (auto p : byte_list) {
|
||||
if (static_cast<int64_t>(res.size()) >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
@ -131,7 +133,7 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
|
|||
void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
|
||||
// Setup inputs
|
||||
std::vector<std::string> str_input{input.Data()};
|
||||
std::list<OffsetMappingType> offset_map;
|
||||
|
@ -140,14 +142,15 @@ void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
std::vector<std::vector<int64_t>> tokenize_results;
|
||||
|
||||
// Only compute offset mapping if optional output for it exists.
|
||||
compute_offset_mapping = false;
|
||||
bool compute_offset_mapping = false;
|
||||
if (offset_mapping.has_value()) {
|
||||
compute_offset_mapping = true;
|
||||
}
|
||||
|
||||
for (auto& str : str_input) {
|
||||
ustring ustr = ustring(str);
|
||||
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_, offset_map));
|
||||
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_,
|
||||
compute_offset_mapping, offset_map));
|
||||
}
|
||||
|
||||
size_t max_length = 0;
|
||||
|
|
|
@ -9,14 +9,13 @@ struct KernelClipBpeTokenizer : BaseKernel {
|
|||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
|
||||
bool compute_offset_mapping;
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
|
||||
|
||||
private:
|
||||
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
|
||||
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map);
|
||||
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
|
||||
std::list<OffsetMappingType>& offset_map) const;
|
||||
|
||||
int64_t padding_length_;
|
||||
std::list<std::pair<int, int>> byte_list_;
|
||||
std::shared_ptr<VocabData> bbpe_tokenizer_;
|
||||
};
|
||||
|
|
|
@ -30,8 +30,9 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& i
|
|||
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
|
||||
}
|
||||
|
||||
std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t max_length) {
|
||||
std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t max_length) const {
|
||||
std::vector<int64_t> res;
|
||||
std::list<std::pair<int, int>> byte_list;
|
||||
|
||||
if (IsEmptyUString(input)) {
|
||||
return res;
|
||||
|
@ -59,14 +60,14 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
|
|||
|
||||
std::string utf8_token = std::string(ustring(tok));
|
||||
|
||||
byte_list_.clear();
|
||||
byte_list.clear();
|
||||
for (char& cp : utf8_token) {
|
||||
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
|
||||
byte_list.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
|
||||
}
|
||||
|
||||
bbpe_tokenizer_->bpe(byte_list_);
|
||||
bbpe_tokenizer_->bpe(byte_list);
|
||||
|
||||
for (auto p : byte_list_) {
|
||||
for (auto p : byte_list) {
|
||||
if (static_cast<int64_t>(res.size()) >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
@ -81,7 +82,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
|
|||
|
||||
void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask) {
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask) const {
|
||||
// Setup inputs
|
||||
std::vector<std::string> str_input{input.Data()};
|
||||
const auto& input_dim = input.Shape();
|
||||
|
|
|
@ -8,12 +8,11 @@ struct KernelBpeTokenizer : BaseKernel {
|
|||
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask);
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask) const;
|
||||
|
||||
private:
|
||||
std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length);
|
||||
std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length) const;
|
||||
|
||||
int64_t padding_length_;
|
||||
std::list<std::pair<int, int>> byte_list_;
|
||||
std::shared_ptr<VocabData> bbpe_tokenizer_;
|
||||
};
|
||||
|
|
|
@ -31,8 +31,11 @@ KernelRobertaBpeTokenizer::KernelRobertaBpeTokenizer(const OrtApi& api, const Or
|
|||
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
|
||||
}
|
||||
|
||||
std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map) {
|
||||
std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t max_length,
|
||||
bool compute_offset_mapping,
|
||||
std::list<OffsetMappingType>& offset_map) const {
|
||||
std::vector<int64_t> res;
|
||||
std::list<std::pair<int, int>> byte_list;
|
||||
|
||||
if (IsEmptyUString(input)) {
|
||||
return res;
|
||||
|
@ -81,16 +84,16 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
|
|||
}
|
||||
|
||||
// Get byte encodings prior to performing BPE
|
||||
byte_list_.clear();
|
||||
byte_list.clear();
|
||||
for (char& cp : utf8_token) {
|
||||
byte_list_.emplace_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
|
||||
byte_list.emplace_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
|
||||
}
|
||||
|
||||
// Perform BPE
|
||||
bbpe_tokenizer_->bpe(byte_list_);
|
||||
bbpe_tokenizer_->bpe(byte_list);
|
||||
|
||||
// Add output to result
|
||||
for (auto p : byte_list_) {
|
||||
for (auto p : byte_list) {
|
||||
if (static_cast<int64_t>(res.size()) >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
@ -119,7 +122,7 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
|
|||
void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
|
||||
// Setup inputs
|
||||
std::vector<std::string> str_input{input.Data()};
|
||||
std::list<OffsetMappingType> offset_map;
|
||||
|
@ -128,14 +131,15 @@ void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
std::vector<std::vector<int64_t>> tokenize_results;
|
||||
|
||||
// Only compute offset mapping if optional output for it exists.
|
||||
compute_offset_mapping = false;
|
||||
bool compute_offset_mapping = false;
|
||||
if (offset_mapping.has_value()) {
|
||||
compute_offset_mapping = true;
|
||||
}
|
||||
|
||||
for (auto& str : str_input) {
|
||||
ustring ustr = ustring(str);
|
||||
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_, offset_map));
|
||||
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_,
|
||||
compute_offset_mapping, offset_map));
|
||||
}
|
||||
|
||||
size_t max_length = 0;
|
||||
|
|
|
@ -9,14 +9,13 @@ struct KernelRobertaBpeTokenizer : BaseKernel {
|
|||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<int64_t>& tokenize_output,
|
||||
std::optional<ortc::Tensor<int64_t>*> attention_mask,
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
|
||||
bool compute_offset_mapping;
|
||||
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
|
||||
|
||||
private:
|
||||
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
|
||||
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map);
|
||||
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
|
||||
std::list<OffsetMappingType>& offset_map) const;
|
||||
|
||||
int64_t padding_length_;
|
||||
std::list<std::pair<int, int>> byte_list_;
|
||||
std::shared_ptr<VocabData> bbpe_tokenizer_;
|
||||
};
|
||||
|
|
|
@ -23,7 +23,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
|||
}
|
||||
|
||||
void Compute(const ortc::Tensor<int64_t>& ids,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
const int64_t* p_ids = ids.Data();
|
||||
auto& ids_dim = ids.Shape();
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ void KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::string>& inpu
|
|||
bool add_eos,
|
||||
bool add_rev,
|
||||
ortc::Tensor<int32_t>& output,
|
||||
ortc::Tensor<int64_t>& output1) {
|
||||
ortc::Tensor<int64_t>& output1) const {
|
||||
// Update with the new API
|
||||
auto& str_input = input.Data();
|
||||
// computation
|
||||
|
|
|
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
|
|||
bool add_eos,
|
||||
bool add_rev,
|
||||
ortc::Tensor<int32_t>& output,
|
||||
ortc::Tensor<int64_t>& output1);
|
||||
ortc::Tensor<int64_t>& output1) const;
|
||||
|
||||
private:
|
||||
sentencepiece::SentencePieceProcessor tokenizer_;
|
||||
|
|
|
@ -127,7 +127,7 @@ void KernelWordpieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
ortc::Tensor<std::string>& output,
|
||||
ortc::Tensor<int64_t>& row_lengths,
|
||||
ortc::Tensor<int64_t>& out_row_begin,
|
||||
ortc::Tensor<int64_t>& output_limit_values) {
|
||||
ortc::Tensor<int64_t>& output_limit_values) const {
|
||||
// Update with the new API
|
||||
// make a copy as we need ustring
|
||||
std::vector<ustring> str_input;
|
||||
|
|
|
@ -17,7 +17,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
|
|||
ortc::Tensor<std::string>& output,
|
||||
ortc::Tensor<int64_t>& row_lengths,
|
||||
ortc::Tensor<int64_t>& out_row_begin,
|
||||
ortc::Tensor<int64_t>& output_limit_values);
|
||||
ortc::Tensor<int64_t>& output_limit_values) const;
|
||||
|
||||
private:
|
||||
int64_t max_input_chars_per_word_;
|
||||
|
|
|
@ -8,8 +8,7 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
|
||||
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const {
|
||||
// Setup inputs
|
||||
const auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 1ULL) {
|
||||
|
|
|
@ -15,7 +15,7 @@ void decode_image(const ortc::Tensor<uint8_t>& input,
|
|||
|
||||
struct KernelDecodeImage : BaseKernel {
|
||||
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
|
||||
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output);
|
||||
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const;
|
||||
};
|
||||
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -153,7 +153,7 @@ void DrawBox(ImageView& image, gsl::span<const float> box, BoundingBoxFormat bbo
|
|||
if (thickness < 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// If not all filled
|
||||
if (thickness != (std::min(x_end - x_start, y_end - y_start))) {
|
||||
auto offset = thickness / 2;
|
||||
|
@ -224,7 +224,7 @@ void DrawBoxesByScore(ImageView& image, const BoxArray& boxes, int64_t thickness
|
|||
|
||||
void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
const ortc::Tensor<float>& input_box,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
ortc::Tensor<uint8_t>& output) const {
|
||||
// Setup inputs
|
||||
const auto& dimensions_bgr = input_bgr.Shape();
|
||||
|
||||
|
@ -260,4 +260,4 @@ void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace ort_extensions
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -39,7 +39,7 @@ struct DrawBoundingBoxes : BaseKernel {
|
|||
|
||||
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
const ortc::Tensor<float>& input_box,
|
||||
ortc::Tensor<uint8_t>& output);
|
||||
ortc::Tensor<uint8_t>& output) const;
|
||||
|
||||
private:
|
||||
int64_t thickness_;
|
||||
|
|
|
@ -7,8 +7,7 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
|
||||
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const {
|
||||
// Setup inputs
|
||||
const auto dimensions_bgr = input.Shape();
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ struct KernelEncodeImage : BaseKernel {
|
|||
}
|
||||
|
||||
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
ortc::Tensor<uint8_t>& output);
|
||||
ortc::Tensor<uint8_t>& output) const;
|
||||
|
||||
private:
|
||||
std::string extension_;
|
||||
|
|
Загрузка…
Ссылка в новой задаче