Make kernel Compute method implementations const (#500)

* Nodes can be called concurrently, so Compute needs to be stateless.

Update the kernels to make Compute const.

* Fix test that uses ustring.h.

It would be better not to have duplicate declarations for GetTensorMutableDataString and FillTensorDataString in ustring.h and string_tensor.h.
This commit is contained in:
Scott McKay 2023-07-28 09:25:36 +10:00 коммит произвёл GitHub
Родитель b8bac85ecd
Коммит e448676a5e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
52 изменённых файлов: 186 добавлений и 160 удалений

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_utils.h"
#include "ustring.h"
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output) {
(void)context;
OrtTensorDimensions dimensions(ort, value);
@ -23,7 +23,7 @@ void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKe
}
}
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output) {
(void)ort;
(void)context;
@ -32,11 +32,11 @@ void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelCo
temp[i] = value[i].c_str();
}
OrtW::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size()));
OrtW::ThrowOnError(api, api.FillStringTensor(output, temp.data(), value.size()));
}
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<ustring>& output) {
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const OrtValue* value, std::vector<ustring>& output) {
std::vector<std::string> utf8_strings;
GetTensorMutableDataString(api, ort, context, value, utf8_strings);
@ -46,12 +46,11 @@ void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKe
}
}
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const std::vector<ustring>& value, OrtValue* output) {
std::vector<std::string> utf8_strings;
utf8_strings.reserve(value.size());
for (const auto& str: value) {
for (const auto& str : value) {
utf8_strings.push_back(std::string(str));
}
FillTensorDataString(api, ort, context, utf8_strings, output);

Просмотреть файл

@ -6,10 +6,9 @@
#include "ocos.h"
#include <string>
// Retrieves a vector of strings if the input type is std::string.
// It is a copy of the input data and can be modified to compute the output.
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output);
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output);

Просмотреть файл

@ -45,9 +45,8 @@ struct hash<ustring> {
};
} // namespace std
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const OrtValue* value, std::vector<ustring>& output);
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
const std::vector<ustring>& value, OrtValue* output);

Просмотреть файл

@ -105,9 +105,9 @@ class Tensor : public TensorBase {
type_ = api_.GetTensorElementType(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
const OrtMemoryInfo* mem_info = {};
api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info);
api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
if (mem_info) {
api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_);
api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
}
}
}
@ -856,7 +856,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
template <typename... Args>
using CustomComputeFn = void (CustomOp::*)(Args...);
using CustomComputeFn = void (CustomOp::*)(Args...) const;
using MyType = OrtLiteCustomStruct<CustomOp>;
struct Kernel {

Просмотреть файл

@ -21,7 +21,6 @@
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
namespace OrtW {
//

Просмотреть файл

@ -38,7 +38,7 @@ struct AudioDecoder : public BaseKernel {
kFLAC
};
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format) {
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format) const {
static const std::map<std::string, AudioStreamType> format_mapping = {
{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
@ -98,7 +98,7 @@ struct AudioDecoder : public BaseKernel {
void Compute(const ortc::Tensor<uint8_t>& input,
const std::optional<std::string> format,
ortc::Tensor<float>& output0) {
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
@ -146,7 +146,7 @@ struct AudioDecoder : public BaseKernel {
}
if (downsample_rate_ != 0 &&
orig_sample_rate < downsample_rate_) {
orig_sample_rate < downsample_rate_) {
ORTX_CXX_API_THROW("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
}

Просмотреть файл

@ -131,7 +131,7 @@ AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& inf
binary_type_ = TryToGetAttributeWithDefault<std::string>(kBinaryType, "");
}
void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) {
void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
@ -196,7 +196,8 @@ void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::
AzureTextInvoker::AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
}
void AzureTextInvoker::Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) {
void AzureTextInvoker::Compute(std::string_view auth, std::string_view input,
ortc::Tensor<std::string>& output) const {
CurlHandler curl_handler(WriteStringCallback);
StringBuffer string_buffer;
@ -313,8 +314,7 @@ int8_t* CreateNonStrTensor(const std::string& data_type,
return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
}
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs,
ortc::Variadic& outputs) {
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);

Просмотреть файл

@ -20,7 +20,7 @@ struct AzureInvoker : public BaseKernel {
struct AzureAudioInvoker : public AzureInvoker {
AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output);
void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const;
private:
std::string binary_type_;
@ -28,7 +28,7 @@ struct AzureAudioInvoker : public AzureInvoker {
struct AzureTextInvoker : public AzureInvoker {
AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output);
void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) const;
private:
std::string binary_type_;
@ -36,7 +36,7 @@ struct AzureTextInvoker : public AzureInvoker {
struct AzureTritonInvoker : public AzureInvoker {
AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
private:
std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;

Просмотреть файл

@ -18,7 +18,7 @@ struct STFT : public BaseKernel {
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) {
ortc::Tensor<float>& output0) const {
auto X = input0.Data();
auto window = input3.data();
auto dimensions = input0.Shape();
@ -77,7 +77,7 @@ struct StftNormal : public STFT {
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) {
ortc::Tensor<float>& output0) const {
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
}
};

Просмотреть файл

@ -11,6 +11,6 @@ KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& inf
void KernelStringEqual::Compute(OrtKernelContext* context,
const ortc::Tensor<std::string>&,
const ortc::Tensor<std::string>&,
ortc::Tensor<bool>& output) {
ortc::Tensor<bool>& output) const {
KernelEqual_Compute<std::string>(api_, ort_, context);
}

Просмотреть файл

@ -11,5 +11,5 @@ struct KernelStringEqual : BaseKernel {
void Compute(OrtKernelContext* context,
const ortc::Tensor<std::string>&,
const ortc::Tensor<std::string>&,
ortc::Tensor<bool>& output);
ortc::Tensor<bool>& output) const;
};

Просмотреть файл

@ -27,7 +27,8 @@ class BroadcastIteratorRight {
}
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
ORTX_CXX_API_THROW(MakeString(
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]),
ORT_INVALID_ARGUMENT);
}
}
cum_shape2_[shape2_.size() - 1] = 1;
@ -114,7 +115,7 @@ inline bool Compare<std::string>::operator()(const std::string& s1, const std::s
}
template <typename T>
void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
void KernelEqual_Compute(const OrtApi& api, const OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const T* X = ort_.GetTensorData<T>(input_X);
@ -144,7 +145,8 @@ void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelCo
}
template <>
void KernelEqual_Compute<std::string>(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
void KernelEqual_Compute<std::string>(const OrtApi& api, const OrtW::CustomOpApi& ort_,
OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);

Просмотреть файл

@ -5,8 +5,8 @@
#include "op_ragged_tensor.hpp"
void KernelRaggedTensoroSparse::Compute(const ortc::Tensor<int64_t>& n_element,
ortc::Tensor<int64_t>& output_0,
ortc::Tensor<int64_t>& output_1) {
ortc::Tensor<int64_t>& output_0,
ortc::Tensor<int64_t>& output_1) const {
const int64_t* p_n_elements = n_element.Data();
auto& d_length = n_element.Shape();
@ -42,14 +42,15 @@ CommonRaggedTensoroDense::CommonRaggedTensoroDense(const OrtApi& api, const OrtK
: BaseKernel(api, info) {
}
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs,
OrtTensorDimensions* dims) const {
for (int i = 0; i < 4; ++i) {
inputs[i] = ort_.KernelContext_GetInput(context, i);
dims[i] = OrtTensorDimensions(ort_, inputs[i]);
}
}
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) const {
int64_t size = n;
int64_t max_col = 0;
for (int64_t i = 1; i < size; ++i) {
@ -64,10 +65,10 @@ KernelRaggedTensoroDense::KernelRaggedTensoroDense(const OrtApi& api, const OrtK
}
void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<int64_t>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<int64_t>& input3,
ortc::Tensor<int64_t>& output) {
const ortc::Tensor<int64_t>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<int64_t>& input3,
ortc::Tensor<int64_t>& output) const {
const int64_t* p_values = input1.Data();
const int64_t* p_missing = input2.Data();
const int64_t* p_indices = input3.Data();
@ -97,14 +98,15 @@ void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
}
}
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensoroDense(api, info) {
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
: CommonRaggedTensoroDense(api, info) {
}
void KernelStringRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<std::string>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<std::string>& input3,
ortc::Tensor<std::string>& output) {
const ortc::Tensor<std::string>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<std::string>& input3,
ortc::Tensor<std::string>& output) const {
// const OrtValue* inputs[4];
// OrtTensorDimensions dims[4];

Просмотреть файл

@ -11,15 +11,15 @@ struct KernelRaggedTensoroSparse : BaseKernel {
void Compute(const ortc::Tensor<int64_t>& n_element,
ortc::Tensor<int64_t>& output_0,
ortc::Tensor<int64_t>& output_1);
ortc::Tensor<int64_t>& output_1) const;
};
struct CommonRaggedTensoroDense : BaseKernel {
CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
protected:
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
int64_t GetMaxCol(int64_t n, const int64_t* p_indices);
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) const;
int64_t GetMaxCol(int64_t n, const int64_t* p_indices) const;
};
struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
@ -28,7 +28,7 @@ struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
const ortc::Tensor<int64_t>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<int64_t>& input3,
ortc::Tensor<int64_t>& output);
ortc::Tensor<int64_t>& output) const;
private:
int64_t missing_value_;
@ -40,5 +40,5 @@ struct KernelStringRaggedTensoroDense : CommonRaggedTensoroDense {
const ortc::Tensor<std::string>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<std::string>& input3,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
};

Просмотреть файл

@ -10,13 +10,13 @@
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
global_replace_ = TryToGetAttributeWithDefault("global_replace",1);
global_replace_ = TryToGetAttributeWithDefault("global_replace", 1);
}
void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
std::string_view str_rewrite,
ortc::Tensor<std::string>& output) {
std::string_view str_pattern,
std::string_view str_rewrite,
ortc::Tensor<std::string>& output) const {
if (str_pattern.empty())
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
@ -38,4 +38,4 @@ void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
}
}
output.SetStringOutput(str_input, dim);
}
}

Просмотреть файл

@ -11,8 +11,8 @@ struct KernelStringRegexReplace : BaseKernel {
void Compute(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
std::string_view str_rewrite,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
protected:
int64_t global_replace_;
};
};

Просмотреть файл

@ -16,7 +16,7 @@ KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, co
void KernelStringECMARegexReplace::Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
// make a copy as input is constant;
std::vector<std::string> str_input = input.Data();
if (pattern.empty()) {

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
void Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
protected:
bool global_replace_;

Просмотреть файл

@ -21,7 +21,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(const ortc::Tensor<std::stri
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3) {
ortc::Tensor<int64_t>& output3) const {
// Setup inputs
auto& str_input = input.Data();

Просмотреть файл

@ -16,7 +16,7 @@ struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3);
ortc::Tensor<int64_t>& output3) const;
private:
bool ignore_case_;

Просмотреть файл

@ -15,20 +15,23 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo&
auto items = SplitString(line, "\t", true);
if (items.size() != 2) {
ORTX_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW(
"[StringMapping]: Should only exist two items in one line, find error in line: " + std::string(line),
ORT_INVALID_GRAPH);
}
map_[std::string(items[0])] = std::string(items[1]);
}
}
void KernelStringMapping::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
// make a copy as input is constant
std::vector<std::string> input_data = input.Data();
for (auto& str : input_data) {
if (map_.find(str) != map_.end()) {
str = map_[str];
auto entry = map_.find(str);
if (entry != map_.end()) {
str = entry->second;
}
}
output.SetStringOutput(input_data, input.Shape());

Просмотреть файл

@ -10,7 +10,7 @@
struct KernelStringMapping : BaseKernel {
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
private:
std::unordered_map<std::string, std::string> map_;

Просмотреть файл

@ -11,7 +11,7 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<int64_t>& output_dim) const {
std::vector<std::vector<int64_t>> result;
// Set output dimension
@ -42,7 +42,8 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
vector_len_ = ParseVectorLen(lines[0]);
if (vector_len_ == 0) {
ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]),
ORT_INVALID_ARGUMENT);
}
std::vector<int64_t> values(vector_len_);
@ -50,7 +51,8 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
ORT_INVALID_ARGUMENT);
}
ParseValues(kv[1], values);
@ -63,14 +65,17 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
auto unk_strs = SplitString(unk, " ", true);
if (unk_strs.size() != vector_len_) {
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(
MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_),
ORT_INVALID_ARGUMENT);
}
for (auto& str : unk_strs) {
int64_t value;
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
if (end != str.data() + str.size()) {
ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str),
ORT_INVALID_ARGUMENT);
}
unk_value_.push_back(value);
@ -81,7 +86,8 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
ORT_INVALID_ARGUMENT);
}
auto value_strs = SplitString(kv[1], " ", true);
@ -95,7 +101,8 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]),
ORT_INVALID_ARGUMENT);
}
values[i] = value;
}
@ -110,7 +117,7 @@ KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInf
}
void KernelStringToVector::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out) {
ortc::Tensor<int64_t>& out) const {
// Setup input
auto& input_data = input.Data();
// Get output

Просмотреть файл

@ -14,7 +14,7 @@ class StringToVectorImpl {
StringToVectorImpl(std::string& map, std::string& unk);
std::vector<std::vector<int64_t>> Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);
std::vector<int64_t>& output_dim) const;
private:
void ParseMappingTable(std::string& map);
@ -32,7 +32,7 @@ class StringToVectorImpl {
struct KernelStringToVector : BaseKernel {
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out);
ortc::Tensor<int64_t>& out) const;
private:
std::shared_ptr<StringToVectorImpl> impl_;

Просмотреть файл

@ -20,7 +20,7 @@ VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk
std::vector<std::string> VectorToStringImpl::Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<int64_t>& output_dim) const {
std::vector<std::string> result;
const int64_t* ptr = static_cast<const int64_t*>(input);
@ -30,7 +30,8 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input,
output_dim = input_dim;
} else {
if (input_dim.empty() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_),
ORT_INVALID_ARGUMENT);
}
output_dim = input_dim;
@ -72,7 +73,8 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
ORT_INVALID_ARGUMENT);
}
ParseValues(kv[1], values);
@ -85,7 +87,8 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) {
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line),
ORT_INVALID_ARGUMENT);
}
auto value_strs = SplitString(kv[1], " ", true);
@ -99,7 +102,8 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]),
ORT_INVALID_ARGUMENT);
}
values[i] = value;
}
@ -114,7 +118,7 @@ KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInf
}
void KernelVectorToString::Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out) {
ortc::Tensor<std::string>& out) const {
const void* input_data = input.Data();
std::vector<int64_t> output_dim;

Просмотреть файл

@ -22,7 +22,7 @@ class VectorToStringImpl {
VectorToStringImpl(std::string& map, std::string& unk);
std::vector<std::string> Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);
std::vector<int64_t>& output_dim) const;
private:
void ParseMappingTable(std::string& map);
@ -37,7 +37,7 @@ class VectorToStringImpl {
struct KernelVectorToString : BaseKernel {
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out);
ortc::Tensor<std::string>& out) const;
private:
std::shared_ptr<VectorToStringImpl> impl_;

Просмотреть файл

@ -88,11 +88,12 @@ KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInf
bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
tokenize_punctuation, remove_control_chars);
}
void KernelBasicTokenizer::Compute(std::string_view input,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
// Setup inputs
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
output.SetStringOutput({result[0].operator std::string()}, {1});

Просмотреть файл

@ -24,7 +24,7 @@ class BasicTokenizer {
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(std::string_view input,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
private:
std::shared_ptr<BasicTokenizer> tokenizer_;

Просмотреть файл

@ -331,7 +331,7 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
auto& input_data = input.Data();
@ -343,7 +343,7 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::list<OffsetMappingType> offset_map;
// Only compute offset mapping if optional output for it exists.
compute_offset_mapping = false;
bool compute_offset_mapping = false;
if (offset_mapping.has_value()) {
compute_offset_mapping = true;
}
@ -397,7 +397,7 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
auto& input_data = input.Data();
@ -408,7 +408,7 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::list<OffsetMappingType> offset_map;
// Only compute offset mapping if optional output for it exists.
compute_offset_mapping = false;
bool compute_offset_mapping = false;
if (offset_mapping.has_value()) {
compute_offset_mapping = true;
}

Просмотреть файл

@ -46,8 +46,10 @@ class WordpieceTokenizer final {
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word = 100);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map,
bool compute_offset_mapping);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map,
bool compute_offset_mapping);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
private:
@ -67,7 +69,8 @@ class BertTokenizer final {
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator, int32_t max_len, const std::string& truncation_strategy);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map, bool compute_offset_mapping);
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map,
bool compute_offset_mapping);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
void Truncate(std::vector<int64_t>& ids);
@ -98,9 +101,8 @@ struct KernelBertTokenizer : BaseKernel {
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
bool compute_offset_mapping;
protected:
std::unique_ptr<BertTokenizer> tokenizer_;
@ -113,5 +115,5 @@ struct KernelHfBertTokenizer : KernelBertTokenizer {
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
};

Просмотреть файл

@ -138,7 +138,7 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const
void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
const ortc::Tensor<int64_t>& positions,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
auto& ids_dim = ids.Shape();
@ -151,7 +151,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
if (use_indices_ &&
(!((positions.NumberOfElement() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices",
ORT_INVALID_GRAPH);
}
const int64_t* p_positions = positions.NumberOfElement() == 0 ? nullptr : positions.Data();
@ -159,7 +160,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
std::vector<std::string> result;
std::vector<int64_t> output_dim(1);
if (!use_indices_) {
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()), skip_special_tokens_, clean_up_tokenization_spaces_));
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()),
skip_special_tokens_, clean_up_tokenization_spaces_));
output_dim[0] = 1;
} else {
if (p_positions != nullptr) {
@ -167,7 +169,8 @@ void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
int64_t start = p_positions[2 * i];
int64_t end = p_positions[2 * i + 1];
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end),
skip_special_tokens_, clean_up_tokenization_spaces_));
}
output_dim[0] = positions_dim[0];
}

Просмотреть файл

@ -33,7 +33,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<int64_t>& ids,
const ortc::Tensor<int64_t>& positions,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
private:
std::shared_ptr<BertTokenizerDecoder> decoder_;

Просмотреть файл

@ -16,7 +16,8 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()),
static_cast<int>(model_data_.size()));
if (model_ptr == nullptr) {
ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
@ -28,11 +29,13 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
}
void KernelBlingFireSentenceBreaker::Compute(std::string_view input,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
int max_length = static_cast<int>(2 * input.size() + 1);
std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()),
output_str.get(), nullptr, nullptr, max_length,
model_.get());
if (output_length < 0) {
ORTX_CXX_API_THROW(MakeString("splitting input:\"", input, "\" failed"), ORT_INVALID_ARGUMENT);
}

Просмотреть файл

@ -18,7 +18,7 @@ extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
struct KernelBlingFireSentenceBreaker : BaseKernel {
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(std::string_view input,
ortc::Tensor<std::string>& output);
ortc::Tensor<std::string>& output) const;
private:
using ModelPtr = std::shared_ptr<void>;

Просмотреть файл

@ -100,7 +100,7 @@ struct KernelBpeDecoder : public BaseKernel {
}
void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
@ -181,4 +181,4 @@ struct KernelBpeDecoder : public BaseKernel {
std::map<char32_t, unsigned char> byte_decoder_;
std::map<int64_t, std::string> added_tokens_;
std::set<int64_t> all_special_ids_;
};
};

Просмотреть файл

@ -31,8 +31,10 @@ KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKerne
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
}
std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map) {
std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const {
std::vector<int64_t> res;
std::list<std::pair<int, int>> byte_list;
WhiteSpaceClean(input);
@ -87,21 +89,21 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
utf8_token.erase(std::remove(utf8_token.begin(), utf8_token.end(), ' '), utf8_token.end());
// Get byte encodings prior to performing BPE
byte_list_.clear();
byte_list.clear();
for (int i = 0; i < utf8_token.length(); i++) {
if (i == utf8_token.length() - 1) {
std::string boundary(1, utf8_token[i]);
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->GetEncoding(boundary + "</w>"), 1));
byte_list.push_back(std::make_pair(bbpe_tokenizer_->GetEncoding(boundary + "</w>"), 1));
} else {
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])], 1));
byte_list.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])], 1));
}
}
// Perform BPE
bbpe_tokenizer_->bpe(byte_list_);
bbpe_tokenizer_->bpe(byte_list);
// Add output to result
for (auto p : byte_list_) {
for (auto p : byte_list) {
if (static_cast<int64_t>(res.size()) >= max_length) {
break;
}
@ -131,7 +133,7 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;
@ -140,14 +142,15 @@ void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::vector<std::vector<int64_t>> tokenize_results;
// Only compute offset mapping if optional output for it exists.
compute_offset_mapping = false;
bool compute_offset_mapping = false;
if (offset_mapping.has_value()) {
compute_offset_mapping = true;
}
for (auto& str : str_input) {
ustring ustr = ustring(str);
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_, offset_map));
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_,
compute_offset_mapping, offset_map));
}
size_t max_length = 0;

Просмотреть файл

@ -9,14 +9,13 @@ struct KernelClipBpeTokenizer : BaseKernel {
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
bool compute_offset_mapping;
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
private:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map);
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;
int64_t padding_length_;
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};

Просмотреть файл

@ -30,8 +30,9 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& i
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
}
std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t max_length) {
std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t max_length) const {
std::vector<int64_t> res;
std::list<std::pair<int, int>> byte_list;
if (IsEmptyUString(input)) {
return res;
@ -59,14 +60,14 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
std::string utf8_token = std::string(ustring(tok));
byte_list_.clear();
byte_list.clear();
for (char& cp : utf8_token) {
byte_list_.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
byte_list.push_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
}
bbpe_tokenizer_->bpe(byte_list_);
bbpe_tokenizer_->bpe(byte_list);
for (auto p : byte_list_) {
for (auto p : byte_list) {
if (static_cast<int64_t>(res.size()) >= max_length) {
break;
}
@ -81,7 +82,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask) {
std::optional<ortc::Tensor<int64_t>*> attention_mask) const {
// Setup inputs
std::vector<std::string> str_input{input.Data()};
const auto& input_dim = input.Shape();

Просмотреть файл

@ -8,12 +8,11 @@ struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask);
std::optional<ortc::Tensor<int64_t>*> attention_mask) const;
private:
std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length);
std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length) const;
int64_t padding_length_;
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};

Просмотреть файл

@ -31,8 +31,11 @@ KernelRobertaBpeTokenizer::KernelRobertaBpeTokenizer(const OrtApi& api, const Or
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
}
std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map) {
std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t max_length,
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const {
std::vector<int64_t> res;
std::list<std::pair<int, int>> byte_list;
if (IsEmptyUString(input)) {
return res;
@ -81,16 +84,16 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
}
// Get byte encodings prior to performing BPE
byte_list_.clear();
byte_list.clear();
for (char& cp : utf8_token) {
byte_list_.emplace_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
byte_list.emplace_back(std::make_pair(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(cp)], 1));
}
// Perform BPE
bbpe_tokenizer_->bpe(byte_list_);
bbpe_tokenizer_->bpe(byte_list);
// Add output to result
for (auto p : byte_list_) {
for (auto p : byte_list) {
if (static_cast<int64_t>(res.size()) >= max_length) {
break;
}
@ -119,7 +122,7 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;
@ -128,14 +131,15 @@ void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::vector<std::vector<int64_t>> tokenize_results;
// Only compute offset mapping if optional output for it exists.
compute_offset_mapping = false;
bool compute_offset_mapping = false;
if (offset_mapping.has_value()) {
compute_offset_mapping = true;
}
for (auto& str : str_input) {
ustring ustr = ustring(str);
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_, offset_map));
tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_,
compute_offset_mapping, offset_map));
}
size_t max_length = 0;

Просмотреть файл

@ -9,14 +9,13 @@ struct KernelRobertaBpeTokenizer : BaseKernel {
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
bool compute_offset_mapping;
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
private:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, std::list<OffsetMappingType>& offset_map);
std::vector<int64_t> Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;
int64_t padding_length_;
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};

Просмотреть файл

@ -23,7 +23,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
}
void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) {
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
auto& ids_dim = ids.Shape();

Просмотреть файл

@ -31,7 +31,7 @@ void KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::string>& inpu
bool add_eos,
bool add_rev,
ortc::Tensor<int32_t>& output,
ortc::Tensor<int64_t>& output1) {
ortc::Tensor<int64_t>& output1) const {
// Update with the new API
auto& str_input = input.Data();
// computation

Просмотреть файл

@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
bool add_eos,
bool add_rev,
ortc::Tensor<int32_t>& output,
ortc::Tensor<int64_t>& output1);
ortc::Tensor<int64_t>& output1) const;
private:
sentencepiece::SentencePieceProcessor tokenizer_;

Просмотреть файл

@ -127,7 +127,7 @@ void KernelWordpieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output,
ortc::Tensor<int64_t>& row_lengths,
ortc::Tensor<int64_t>& out_row_begin,
ortc::Tensor<int64_t>& output_limit_values) {
ortc::Tensor<int64_t>& output_limit_values) const {
// Update with the new API
// make a copy as we need ustring
std::vector<ustring> str_input;

Просмотреть файл

@ -17,7 +17,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
ortc::Tensor<std::string>& output,
ortc::Tensor<int64_t>& row_lengths,
ortc::Tensor<int64_t>& out_row_begin,
ortc::Tensor<int64_t>& output_limit_values);
ortc::Tensor<int64_t>& output_limit_values) const;
private:
int64_t max_input_chars_per_word_;

Просмотреть файл

@ -8,8 +8,7 @@
namespace ort_extensions {
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const {
// Setup inputs
const auto& dimensions = input.Shape();
if (dimensions.size() != 1ULL) {

Просмотреть файл

@ -15,7 +15,7 @@ void decode_image(const ortc::Tensor<uint8_t>& input,
struct KernelDecodeImage : BaseKernel {
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output);
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const;
};
} // namespace ort_extensions

Просмотреть файл

@ -153,7 +153,7 @@ void DrawBox(ImageView& image, gsl::span<const float> box, BoundingBoxFormat bbo
if (thickness < 1) {
return;
}
// If not all filled
if (thickness != (std::min(x_end - x_start, y_end - y_start))) {
auto offset = thickness / 2;
@ -224,7 +224,7 @@ void DrawBoxesByScore(ImageView& image, const BoxArray& boxes, int64_t thickness
void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
const ortc::Tensor<float>& input_box,
ortc::Tensor<uint8_t>& output) {
ortc::Tensor<uint8_t>& output) const {
// Setup inputs
const auto& dimensions_bgr = input_bgr.Shape();
@ -260,4 +260,4 @@ void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
}
}
} // namespace ort_extensions
} // namespace ort_extensions

Просмотреть файл

@ -39,7 +39,7 @@ struct DrawBoundingBoxes : BaseKernel {
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
const ortc::Tensor<float>& input_box,
ortc::Tensor<uint8_t>& output);
ortc::Tensor<uint8_t>& output) const;
private:
int64_t thickness_;

Просмотреть файл

@ -7,8 +7,7 @@
namespace ort_extensions {
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) const {
// Setup inputs
const auto dimensions_bgr = input.Shape();

Просмотреть файл

@ -21,7 +21,7 @@ struct KernelEncodeImage : BaseKernel {
}
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
ortc::Tensor<uint8_t>& output);
ortc::Tensor<uint8_t>& output) const;
private:
std::string extension_;