Partner team's code security fixings (#300)
This commit is contained in:
Родитель
72790873e5
Коммит
7fc0224410
|
@ -70,19 +70,25 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
template <typename... Args>
|
||||
class CuopContainer {
|
||||
public:
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
#endif
|
||||
CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
|
||||
ocos_list_.push_back(nullptr);
|
||||
}
|
||||
|
||||
~CuopContainer() {
|
||||
// skip the last null pointer.
|
||||
for (auto i = 0; i < ocos_list_.size() - 1; i++) {
|
||||
delete ocos_list_[i];
|
||||
if (0 < ocos_list_.size()) {
|
||||
for (size_t i = 0; i < ocos_list_.size() - 1; i++) {
|
||||
delete ocos_list_[i];
|
||||
}
|
||||
}
|
||||
|
||||
ocos_list_.clear();
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
const OrtCustomOp** GetList() {
|
||||
return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
|
||||
}
|
||||
|
|
|
@ -32,16 +32,18 @@ struct KernelGaussianBlur : BaseKernel {
|
|||
|
||||
OrtTensorDimensions input_data_dimensions(ort_, input_data);
|
||||
|
||||
int n = input_data_dimensions[0];
|
||||
int h = input_data_dimensions[1];
|
||||
int w = input_data_dimensions[2];
|
||||
int c = input_data_dimensions[3];
|
||||
int n = static_cast<int>(input_data_dimensions[0]);
|
||||
int h = static_cast<int>(input_data_dimensions[1]);
|
||||
int w = static_cast<int>(input_data_dimensions[2]);
|
||||
int c = static_cast<int>(input_data_dimensions[3]);
|
||||
(void)n;
|
||||
(void)c;
|
||||
|
||||
cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
|
||||
cv::Mat output_image;
|
||||
cv::GaussianBlur(input_image,
|
||||
output_image,
|
||||
cv::Size(ksize[1], ksize[0]),
|
||||
cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
|
||||
sigma[0], sigma[1], cv::BORDER_DEFAULT);
|
||||
|
||||
OrtValue* image_y = ort_.KernelContext_GetOutput(context,
|
||||
|
@ -70,7 +72,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
|
|||
}
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
|||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
|||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -48,7 +48,7 @@ struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
|||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
|||
return 2;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -55,6 +55,6 @@ const char* CustomOpSegmentExtraction::GetName() const {
|
|||
return "SegmentExtraction";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t index) const {
|
||||
ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
|
|
@ -30,7 +30,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
|||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||
T* p_output = ort_.GetTensorMutableData<T>(v);
|
||||
int64_t out_size = dim_out.Size();
|
||||
memset(p_output, 0, out_size * sizeof(T));
|
||||
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(T)));
|
||||
|
||||
// The implementation is naive. It could be parallelized and
|
||||
// use SIMD instructions to be faster.
|
||||
|
|
|
@ -38,7 +38,7 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
|
|||
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
|
||||
int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
data_ptr[i] = data[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<std::string>& output) {
|
||||
(void)context;
|
||||
OrtTensorDimensions dimensions(ort, value);
|
||||
size_t len = static_cast<size_t>(dimensions.Size());
|
||||
size_t data_len;
|
||||
|
@ -15,14 +16,16 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
|
|||
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
||||
output.resize(len);
|
||||
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
|
||||
if (i < len - 1)
|
||||
result[offsets[i + (int64_t)1]] = '\0';
|
||||
output[i] = result.data() + offsets[i];
|
||||
if (i < static_cast<int64_t>(len) - 1)
|
||||
result[offsets[static_cast<size_t>(i + (int64_t)1)]] = '\0';
|
||||
output[static_cast<size_t>(i)] = result.data() + offsets[static_cast<size_t>(i)];
|
||||
}
|
||||
}
|
||||
|
||||
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const std::vector<std::string>& value, OrtValue* output) {
|
||||
(void)ort;
|
||||
(void)context;
|
||||
std::vector<const char*> temp(value.size());
|
||||
for (size_t i = 0; i < value.size(); ++i) {
|
||||
temp[i] = value[i].c_str();
|
||||
|
|
|
@ -14,7 +14,7 @@ inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
|||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
for (size_t i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions
|
|||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
for (size_t i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
std::vector<std::string> result;
|
||||
std::vector<int64_t> result_dimension;
|
||||
|
||||
for (int i = 0; i < value.size(); i++) {
|
||||
for (size_t i = 0; i < value.size(); i++) {
|
||||
if (!mask[i]) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ class BroadcastIteratorRight {
|
|||
public:
|
||||
BroadcastIteratorRight(const std::vector<int64_t>& shape1,
|
||||
const std::vector<int64_t>& shape2,
|
||||
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
|
||||
const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) {
|
||||
if (shape2.size() > shape1.size())
|
||||
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
|
||||
shape2_.resize(shape1_.size());
|
||||
|
@ -82,7 +82,7 @@ class BroadcastIteratorRight {
|
|||
}
|
||||
|
||||
template <typename TCMP>
|
||||
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
|
||||
void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) {
|
||||
if (pos != 0)
|
||||
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
|
||||
while (!end()) {
|
||||
|
|
|
@ -155,7 +155,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
std::vector<int64_t> shape_out{size - 1, max_col};
|
||||
|
||||
int64_t shape_out_size = shape_out[0] * shape_out[1];
|
||||
std::vector<std::string> dense(max_col * (size - 1));
|
||||
std::vector<std::string> dense(static_cast<size_t>(max_col * (size - 1)));
|
||||
int64_t pos = 0;
|
||||
int64_t j, pos_end;
|
||||
for (int64_t i = 0; i < size - 1; ++i) {
|
||||
|
@ -165,7 +165,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||
dense[pos] = input[j];
|
||||
dense[static_cast<size_t>(pos)] = input[static_cast<size_t>(j)];
|
||||
}
|
||||
pos = pos_end;
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
size_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
re2::StringPiece piece(str_rewrite[0]);
|
||||
|
@ -50,11 +50,11 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
|
||||
// Do computation
|
||||
if (global_replace_) {
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
re2::RE2::Replace(&(str_input[i]), reg, piece);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
std::vector<std::string_view> tokens;
|
||||
std::vector<int64_t> begin_offsets;
|
||||
std::vector<int64_t> end_offsets;
|
||||
RegexSplitImpl(str_input[i], reg,
|
||||
RegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
|
||||
include_delimiter, keep_reg,
|
||||
tokens, begin_offsets, end_offsets);
|
||||
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
|
||||
|
|
|
@ -29,7 +29,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
|||
GetTensorMutableDataString(api_, ort_, context, right, right_value);
|
||||
|
||||
// reuse left_value as output to save memory
|
||||
for (int i = 0; i < left_value.size(); i++) {
|
||||
for (size_t i = 0; i < left_value.size(); i++) {
|
||||
left_value[i].append(right_value[i]);
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
|||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
size_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
auto regex_flag = std::regex_constants::optimize | std::regex_constants::ECMAScript;
|
||||
|
@ -58,11 +58,11 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
|||
std::regex reg(str_pattern[0], regex_flag);
|
||||
|
||||
if (global_replace_) {
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0]);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0], std::regex_constants::format_first_only);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
std::vector<std::string_view> tokens;
|
||||
std::vector<int64_t> begin_offsets;
|
||||
std::vector<int64_t> end_offsets;
|
||||
ECMARegexSplitImpl(str_input[i], reg,
|
||||
ECMARegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
|
||||
include_delimiter, keep_reg,
|
||||
tokens, begin_offsets, end_offsets);
|
||||
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
|
||||
|
|
|
@ -32,9 +32,9 @@ void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
|
|||
std::vector<T>& end_offsets) {
|
||||
size_t prev_pos = 0;
|
||||
for (auto it = std::sregex_iterator(input.begin(), input.end(), pattern); it != std::sregex_iterator(); it++) {
|
||||
size_t cur_pos = it->position();
|
||||
size_t matched_length = it->length();
|
||||
if (prev_pos != it->position()) {
|
||||
int cur_pos = static_cast<int>(it->position());
|
||||
int matched_length = static_cast<int>(it->length());
|
||||
if (static_cast<decltype(it->position())>(prev_pos) != it->position()) {
|
||||
tokens.emplace_back(input.c_str() + prev_pos, cur_pos - prev_pos);
|
||||
begin_offsets.push_back(prev_pos);
|
||||
end_offsets.push_back(cur_pos);
|
||||
|
|
|
@ -33,12 +33,12 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
|||
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
size_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
size_t nb = static_cast<size_t>(*p_num_buckets);
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
out[i] = static_cast<int64_t>(Hash64(str_input[i].c_str(), str_input[i].size()) % nb);
|
||||
}
|
||||
}
|
||||
|
@ -92,12 +92,12 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
|||
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
size_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
size_t nb = static_cast<size_t>(*p_num_buckets);
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
out[i] = static_cast<int64_t>(util::Fingerprint64(str_input[i].c_str(), str_input[i].size()) % nb);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
if (X.size() != 1)
|
||||
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
|
||||
} else {
|
||||
if (*axis < 0 || *axis >= dimensions.size())
|
||||
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
|
||||
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
|
@ -49,18 +49,18 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
std::vector<std::string> out(size);
|
||||
std::vector<std::string> out(static_cast<size_t>(size));
|
||||
|
||||
if (dimensions.size() > 0) {
|
||||
if (X.size() > 0) {
|
||||
// Do computation
|
||||
int64_t h = 1;
|
||||
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
|
||||
for (size_t i = static_cast<size_t>(*axis + 1); i < dimensions.size(); ++i) {
|
||||
h *= dimensions[i];
|
||||
}
|
||||
int64_t left_part = size / h;
|
||||
int64_t right_part = size / left_part;
|
||||
int64_t n_red = dimensions[*axis] - 1;
|
||||
int64_t n_red = dimensions[static_cast<size_t>(*axis)] - 1;
|
||||
int64_t inc = right_part * (n_red + 1);
|
||||
int64_t pos = 0;
|
||||
for (int64_t li = 0; li < left_part; ++li) {
|
||||
|
@ -68,10 +68,10 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
std::ostringstream st;
|
||||
int64_t index = ri + li * inc;
|
||||
for (int64_t j = 0; j < n_red; ++j, index += h) {
|
||||
st << X[index] << sep[0];
|
||||
st << X[static_cast<size_t>(index)] << sep[0];
|
||||
}
|
||||
st << X[index];
|
||||
out[pos] = st.str();
|
||||
st << X[static_cast<size_t>(index)];
|
||||
out[static_cast<size_t>(pos)] = st.str();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -16,7 +16,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
|
|||
std::vector<std::string> X;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
|
||||
for (size_t i = 0; i < X.size(); ++i) {
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) {return static_cast<char>(ToLower(c));});
|
||||
}
|
||||
|
||||
|
|
|
@ -36,10 +36,10 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
if (delimiter.size() == 0) {
|
||||
char word[2] = "a";
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
const std::string& str = X[static_cast<size_t>(row)];
|
||||
if (str.empty())
|
||||
continue;
|
||||
maxc = str.size() > maxc ? str.size() : maxc;
|
||||
maxc = static_cast<int64_t>(static_cast<int64_t>(str.size()) > maxc ? str.size() : maxc);
|
||||
for (auto it = str.begin(); it != str.end(); ++it) {
|
||||
word[0] = *it;
|
||||
words.push_back(word);
|
||||
|
@ -51,7 +51,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
bool keep = !(*skip_empty);
|
||||
std::size_t current, previous = 0;
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
const std::string& str = X[static_cast<size_t>(row)];
|
||||
if (str.empty())
|
||||
continue;
|
||||
previous = 0;
|
||||
|
|
|
@ -91,7 +91,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
std::vector<std::string_view> value_strs = SplitString(v, " ", true);
|
||||
|
||||
int64_t value;
|
||||
for (int i = 0; i < value_strs.size(); i++) {
|
||||
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
|
|
|
@ -16,8 +16,8 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
|
|||
std::vector<std::string> X;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), ::toupper);
|
||||
for (size_t i = 0; i < X.size(); ++i) {
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c){ return static_cast<char>(::toupper(c)); });
|
||||
}
|
||||
|
||||
// Fills the output
|
||||
|
|
|
@ -28,7 +28,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
|||
// only hit when the key is a scalar and the input is a vector
|
||||
output_dim = input_dim;
|
||||
} else {
|
||||
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != vector_len_) {
|
||||
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
|
||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
|
@ -37,9 +37,9 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
|||
}
|
||||
|
||||
std::vector<int64_t> key(vector_len_);
|
||||
for (int i = 0; i < input_dim.Size(); i += vector_len_) {
|
||||
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
|
||||
//construct key
|
||||
for (int j = 0; j < vector_len_; j++) {
|
||||
for (size_t j = 0; j < vector_len_; j++) {
|
||||
key[j] = ptr[j];
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
std::vector<std::string_view> value_strs = SplitString(v, " ", true);
|
||||
|
||||
int64_t value;
|
||||
for (int i = 0; i < value_strs.size(); i++) {
|
||||
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||
|
|
|
@ -10,8 +10,12 @@
|
|||
#include <algorithm>
|
||||
|
||||
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
|
||||
do_lower_case_(do_lower_case), tokenize_chinese_chars_(tokenize_chinese_chars), strip_accents_(strip_accents), tokenize_punctuation_(tokenize_punctuation),
|
||||
remove_control_chars_(remove_control_chars){}
|
||||
do_lower_case_(do_lower_case),
|
||||
strip_accents_(strip_accents),
|
||||
tokenize_chinese_chars_(tokenize_chinese_chars),
|
||||
tokenize_punctuation_(tokenize_punctuation),
|
||||
remove_control_chars_(remove_control_chars)
|
||||
{}
|
||||
|
||||
std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
||||
std::vector<ustring> result;
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
BertTokenizerVocab::BertTokenizerVocab(std::string_view vocab) : raw_vocab_(vocab) {
|
||||
auto tokens = SplitString(raw_vocab_, "\r\n", true);
|
||||
|
||||
for (int i = 0; i < tokens.size(); i++) {
|
||||
(vocab_)[tokens[i]] = i;
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
(vocab_)[tokens[i]] = static_cast<int32_t>(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,11 +40,16 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
|
|||
}
|
||||
|
||||
WordpieceTokenizer::WordpieceTokenizer(
|
||||
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
|
||||
ustring suffix_indicator, int max_input_chars_per_word)
|
||||
: vocab_(std::move(vocab)), unk_token_(std::move(unk_token))
|
||||
, suffix_indicator_(std::move(suffix_indicator))
|
||||
, max_input_chars_per_word_(max_input_chars_per_word) {
|
||||
std::shared_ptr<BertTokenizerVocab> vocab,
|
||||
ustring unk_token,
|
||||
ustring suffix_indicator,
|
||||
int max_input_chars_per_word
|
||||
) :
|
||||
max_input_chars_per_word_(max_input_chars_per_word),
|
||||
suffix_indicator_(std::move(suffix_indicator)),
|
||||
unk_token_(std::move(unk_token)),
|
||||
vocab_(std::move(vocab))
|
||||
{
|
||||
unk_token_id_ = vocab_->FindTokenId(unk_token_);
|
||||
}
|
||||
|
||||
|
@ -92,7 +97,7 @@ std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& toke
|
|||
}
|
||||
|
||||
void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
|
||||
if (token.size() > max_input_chars_per_word_) {
|
||||
if (static_cast<int64_t>(token.size()) > max_input_chars_per_word_) {
|
||||
tokenized_result.push_back(unk_token_);
|
||||
return;
|
||||
}
|
||||
|
@ -127,13 +132,13 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
|
|||
}
|
||||
|
||||
void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int32_t max_len) {
|
||||
if ((max_len > 0) && (max_len < ids.size())) {
|
||||
if ((max_len > 0) && (static_cast<size_t>(max_len) < ids.size())) {
|
||||
ids.resize(max_len);
|
||||
}
|
||||
}
|
||||
|
||||
void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len) {
|
||||
if (max_len < 0 || (ids1.size() + ids2.size() <= max_len)) {
|
||||
if (max_len < 0 || (ids1.size() + ids2.size() <= static_cast<size_t>(max_len))) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -145,7 +150,7 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>
|
|||
case TruncateStrategyType::LONGEST_FIRST:
|
||||
case TruncateStrategyType::LONGEST_FROM_BACK:
|
||||
|
||||
if ((ids1_keep_len > half_max_len) && (ids2_keep_len > half_max_len)) {
|
||||
if ((ids1_keep_len > static_cast<size_t>(half_max_len)) && (ids2_keep_len > static_cast<size_t>(half_max_len))) {
|
||||
ids1_keep_len = static_cast<size_t>(max_len) - half_max_len;
|
||||
ids2_keep_len = half_max_len;
|
||||
} else if (ids2_keep_len > ids1_keep_len) {
|
||||
|
@ -173,12 +178,24 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>
|
|||
}
|
||||
|
||||
BertTokenizer::BertTokenizer(
|
||||
const std::string& vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token,
|
||||
ustring sep_token, ustring pad_token, ustring cls_token, ustring mask_token,
|
||||
bool tokenize_chinese_chars, bool strip_accents, ustring suffix_indicator, int32_t max_len,
|
||||
const std::string& truncation_strategy)
|
||||
: do_basic_tokenize_(do_basic_tokenize), max_length_(max_len)
|
||||
, truncate_(std::make_unique<TruncateStrategy>(truncation_strategy)) {
|
||||
const std::string& vocab,
|
||||
bool do_lower_case,
|
||||
bool do_basic_tokenize,
|
||||
ustring unk_token,
|
||||
ustring sep_token,
|
||||
ustring pad_token,
|
||||
ustring cls_token,
|
||||
ustring mask_token,
|
||||
bool tokenize_chinese_chars,
|
||||
bool strip_accents,
|
||||
ustring suffix_indicator,
|
||||
int32_t max_len,
|
||||
const std::string& truncation_strategy
|
||||
) :
|
||||
max_length_(max_len),
|
||||
do_basic_tokenize_(do_basic_tokenize),
|
||||
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy))
|
||||
{
|
||||
|
||||
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
|
||||
|
||||
|
|
|
@ -1,25 +1,36 @@
|
|||
#include "bert_tokenizer_decoder.hpp"
|
||||
|
||||
BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, std::string unk_token, std::string sep_token, std::string pad_token,
|
||||
std::string cls_token,std::string mask_token,std::string suffix_indicator) : raw_vocab_(vocab), unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
|
||||
BertTokenizerDecoder::BertTokenizerDecoder(
|
||||
std::string vocab,
|
||||
std::string unk_token,
|
||||
std::string sep_token,
|
||||
std::string pad_token,
|
||||
std::string cls_token,
|
||||
std::string mask_token,
|
||||
std::string suffix_indicator
|
||||
) :
|
||||
unk_token_(unk_token),
|
||||
suffix_indicator_(suffix_indicator),
|
||||
raw_vocab_(vocab)
|
||||
{
|
||||
auto tokens = SplitString(raw_vocab_, "\n", true);
|
||||
vocab_.reserve(tokens.size());
|
||||
for (int i = 0; i < tokens.size(); i++) {
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
auto& token = tokens[i];
|
||||
if (token == unk_token) {
|
||||
unk_token_id_ = i;
|
||||
unk_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
if (token == sep_token) {
|
||||
sep_token_id_ = i;
|
||||
sep_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
if (token == pad_token) {
|
||||
sep_token_id_ = i;
|
||||
sep_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
if (token == cls_token) {
|
||||
cls_token_id_ = i;
|
||||
cls_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
if (token == mask_token) {
|
||||
mask_token_id_ = i;
|
||||
mask_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
|
||||
if (token.rfind(suffix_indicator_, 0) == 0) {
|
||||
|
@ -42,7 +53,7 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
|
|||
}
|
||||
|
||||
// deal with unk ids
|
||||
if (id >= vocab_.size() || id < 0) {
|
||||
if (id < 0 || static_cast<size_t>(id) >= vocab_.size()) {
|
||||
if (!result.empty()) {
|
||||
result.push_back(' ');
|
||||
}
|
||||
|
@ -51,7 +62,7 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
|
|||
}
|
||||
|
||||
// skip first substr
|
||||
if (result.empty() && is_substr_[id]) {
|
||||
if (result.empty() && is_substr_[static_cast<size_t>(id)]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -59,11 +70,11 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
|
|||
// we needn't add a space at the beginning of the output
|
||||
// we needn't add a space when the token is a substr (such as ##ing)
|
||||
// we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
|
||||
if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
|
||||
if (!(result.empty() || is_substr_[static_cast<size_t>(id)] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
|
||||
result.push_back(' ');
|
||||
}
|
||||
|
||||
result.append(vocab_[id]);
|
||||
result.append(vocab_[static_cast<size_t>(id)]);
|
||||
pre_token = id;
|
||||
}
|
||||
|
||||
|
@ -76,8 +87,8 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
|
|||
return true;
|
||||
}
|
||||
|
||||
auto pre_char = ustring(vocab_[pre_token_id]).back();
|
||||
auto cur_char = ustring(vocab_[new_token_id])[0];
|
||||
auto pre_char = ustring(vocab_[static_cast<size_t>(pre_token_id)]).back();
|
||||
auto cur_char = ustring(vocab_[static_cast<size_t>(new_token_id)])[0];
|
||||
|
||||
// normal punctuation
|
||||
if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
|
||||
|
|
|
@ -15,7 +15,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
|
||||
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
|
||||
|
||||
if (model_ptr == nullptr) {
|
||||
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
||||
|
@ -24,7 +24,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||
|
||||
if (HasAttribute("max_sentence")) {
|
||||
max_sentence = ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence");
|
||||
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,10 +42,10 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
|
||||
std::string& input_string = input_data[0];
|
||||
int max_length = 2 * input_string.size() + 1;
|
||||
int max_length = static_cast<int>(2 * input_string.size() + 1);
|
||||
std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
|
||||
|
||||
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.get(), nullptr, nullptr, max_length, model_.get());
|
||||
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
|
||||
if (output_length < 0) {
|
||||
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
|
|
@ -167,8 +167,8 @@ class VocabData {
|
|||
line = std::regex_replace(line, std::regex("\r"), "");
|
||||
ustring line_32(line);
|
||||
int id = static_cast<int>(vocab_map_.size());
|
||||
if (auto it = vocab_map_.find(line); it != vocab_map_.end()) {
|
||||
id = it->second;
|
||||
if (auto nestedIt = vocab_map_.find(line); nestedIt != vocab_map_.end()) {
|
||||
id = nestedIt->second;
|
||||
} else {
|
||||
vocab_map_[line] = id;
|
||||
}
|
||||
|
@ -237,7 +237,7 @@ class VocabData {
|
|||
}
|
||||
|
||||
const std::string& IdToToken(int id) const {
|
||||
if ((id < 0) || (id >= id2token_map_.size())) {
|
||||
if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) {
|
||||
ORT_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
return id2token_map_[id];
|
||||
|
@ -500,7 +500,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
|
|||
TokenWithRegularExp regcmp;
|
||||
|
||||
for (auto& seg_id : special_token_split_res) {
|
||||
if (res.size() >= max_length) break;
|
||||
if (static_cast<int64_t>(res.size()) >= max_length) break;
|
||||
|
||||
if (seg_id.second != -1) {
|
||||
res.push_back(seg_id.second);
|
||||
|
@ -512,7 +512,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
|
|||
const char32_t* ptr = cur_input.c_str();
|
||||
regcmp.Set(ptr);
|
||||
|
||||
while (res.size() < max_length) {
|
||||
while (static_cast<int64_t>(res.size()) < max_length) {
|
||||
auto [b, tok] = regcmp.GetNextToken();
|
||||
if (!b) break;
|
||||
|
||||
|
@ -526,7 +526,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
|
|||
bbpe_tokenizer_->bpe(byte_list_);
|
||||
|
||||
for (auto p : byte_list_) {
|
||||
if (res.size() >= max_length) {
|
||||
if (static_cast<int64_t>(res.size()) >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -556,7 +556,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
|
|||
max_length = std::max(max_length, res.size());
|
||||
}
|
||||
} else {
|
||||
max_length = padding_length_;
|
||||
max_length = static_cast<size_t>(padding_length_);
|
||||
}
|
||||
|
||||
OrtTensorDimensions output_dim = input_dim;
|
||||
|
@ -574,7 +574,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
|
|||
idx++;
|
||||
}
|
||||
|
||||
for (int i = res.size(); i < max_length; i++) {
|
||||
for (size_t i = res.size(); i < max_length; i++) {
|
||||
token[idx] = 0;
|
||||
mask[idx] = 0;
|
||||
idx++;
|
||||
|
@ -594,14 +594,14 @@ size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
|
|||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t index) const {
|
||||
ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
}
|
||||
size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t index) const {
|
||||
ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
|
|||
sentencepiece::ModelProto model_proto;
|
||||
std::vector<uint8_t> model_as_bytes;
|
||||
if (base64_decode(model_as_string, model_as_bytes)) {
|
||||
model_proto.ParseFromArray(model_as_bytes.data(), model_as_bytes.size());
|
||||
model_proto.ParseFromArray(model_as_bytes.data(), static_cast<int>(model_as_bytes.size()));
|
||||
} else {
|
||||
model_proto.ParseFromArray(model_as_string.c_str(), model_as_string.size());
|
||||
model_proto.ParseFromArray(model_as_string.c_str(), static_cast<int>(model_as_string.size()));
|
||||
}
|
||||
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
|
||||
if (!status.ok())
|
||||
|
@ -46,6 +46,9 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
const OrtValue* ort_add_rev = ort_.KernelContext_GetInput(context, 5);
|
||||
const bool* p_add_rev = ort_.GetTensorData<bool>(ort_add_rev);
|
||||
|
||||
(void)p_nbest_size;
|
||||
(void)p_alpha;
|
||||
|
||||
// Verifications
|
||||
_check_dimension_constant(ort_, ort_nbest_size, "nbest_size");
|
||||
_check_dimension_constant(ort_, ort_alpha, "alpha");
|
||||
|
|
|
@ -24,7 +24,7 @@ KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtK
|
|||
}
|
||||
}
|
||||
|
||||
void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
|
||||
void KernelWordpieceTokenizer_Split(const std::u32string& /*suffix_indicator*/,
|
||||
const std::u32string& text,
|
||||
std::vector<std::u32string>& words) {
|
||||
ustring space(" ");
|
||||
|
@ -81,7 +81,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
|||
KernelWordpieceTokenizer_Split(suffix_indicator, *it, words);
|
||||
|
||||
for (auto itk = words.begin(); itk != words.end(); ++itk) {
|
||||
if (itk->size() > max_input_chars_per_word) {
|
||||
if (static_cast<int64_t>(itk->size()) > max_input_chars_per_word) {
|
||||
indices.push_back(-1);
|
||||
tokens.push_back(unk_token);
|
||||
continue;
|
||||
|
@ -107,7 +107,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
|||
is_bad = true;
|
||||
break;
|
||||
}
|
||||
indices.push_back(cur_substr);
|
||||
indices.push_back(static_cast<int32_t>(cur_substr));
|
||||
tokens.push_back(ustring(token));
|
||||
start = end;
|
||||
}
|
||||
|
@ -153,13 +153,13 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
|
||||
int64_t i;
|
||||
for (i = 0; i < size_row_lengths[0]; ++i) {
|
||||
ptr_row_lengths[i] = row_begins[i];
|
||||
ptr_row_begins[i] = row_begins[i];
|
||||
ptr_limit_values[i] = row_begins[i + 1];
|
||||
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
|
||||
ptr_row_begins[i] = row_begins[static_cast<size_t>(i)];
|
||||
ptr_limit_values[i] = row_begins[static_cast<size_t>(i + 1)];
|
||||
}
|
||||
|
||||
i = size_row_lengths[0];
|
||||
ptr_row_lengths[i] = row_begins[i];
|
||||
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
|
||||
}
|
||||
|
||||
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
|
|
|
@ -69,12 +69,12 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
OrtStatus* status = nullptr;
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
if (status = RegisterPythonDomainAndOps(options, ortApi)){
|
||||
if (status = RegisterPythonDomainAndOps(options, ortApi); status) {
|
||||
return status;
|
||||
}
|
||||
#endif // PYTHON_OP_SUPPORT
|
||||
|
||||
if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
|
||||
if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain); status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
size_t count = 0;
|
||||
const OrtCustomOp* c_ops = FetchPyCustomOps(count);
|
||||
while (c_ops != nullptr) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, c_ops); status) {
|
||||
return status;
|
||||
} else {
|
||||
pyop_nameset.emplace(c_ops->GetName(c_ops));
|
||||
|
@ -114,7 +114,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
auto ops = fx();
|
||||
while (*ops != nullptr) {
|
||||
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, *ops)) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, *ops); status) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -126,7 +126,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
|
||||
while (e_ops != nullptr) {
|
||||
if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, e_ops)) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, e_ops); status) {
|
||||
return status;
|
||||
}
|
||||
e_ops = ExternalCustomOps::instance().GetNextOp(idx);
|
||||
|
|
Загрузка…
Ссылка в новой задаче