Partner team's code security fixes (#300)

This commit is contained in:
Wenbing Li 2022-10-05 16:10:34 -07:00 коммит произвёл GitHub
Родитель 72790873e5
Коммит 7fc0224410
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
33 изменённых файлов: 169 добавлений и 123 удалений

Просмотреть файл

@ -70,19 +70,25 @@ struct OrtTensorDimensions : std::vector<int64_t> {
template <typename... Args>
class CuopContainer {
public:
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
// Instantiates one object of each Args type and stores the pointers,
// then appends a trailing nullptr that marks the end of the list
// (the destructor and GetList() rely on this sentinel).
CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
ocos_list_.push_back(nullptr);
}
~CuopContainer() {
  // The last element is the nullptr sentinel pushed by the constructor, so
  // the loop stops one short of size(). The emptiness guard keeps
  // `size() - 1` from underflowing (size() is unsigned) if the vector were
  // ever empty, and the index is size_t to avoid a signed/unsigned mismatch.
  if (0 < ocos_list_.size()) {
    for (size_t i = 0; i < ocos_list_.size() - 1; i++) {
      delete ocos_list_[i];
    }
  }
  ocos_list_.clear();
}
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
// Returns the raw pointer array backing the container. The trailing entry
// is the nullptr pushed by the constructor, so callers can treat the result
// as a null-terminated list. The const_cast reinterprets the stored
// `OrtCustomOp*` elements as `const OrtCustomOp*` for the C API.
const OrtCustomOp** GetList() {
return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
}

Просмотреть файл

@ -32,16 +32,18 @@ struct KernelGaussianBlur : BaseKernel {
OrtTensorDimensions input_data_dimensions(ort_, input_data);
int n = input_data_dimensions[0];
int h = input_data_dimensions[1];
int w = input_data_dimensions[2];
int c = input_data_dimensions[3];
int n = static_cast<int>(input_data_dimensions[0]);
int h = static_cast<int>(input_data_dimensions[1]);
int w = static_cast<int>(input_data_dimensions[2]);
int c = static_cast<int>(input_data_dimensions[3]);
(void)n;
(void)c;
cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
cv::Mat output_image;
cv::GaussianBlur(input_image,
output_image,
cv::Size(ksize[1], ksize[0]),
cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
sigma[0], sigma[1], cv::BORDER_DEFAULT);
OrtValue* image_y = ort_.KernelContext_GetOutput(context,
@ -70,7 +72,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
}
}
// The GaussianBlur output is always a float tensor regardless of index;
// the parameter is unnamed to avoid an unused-parameter warning.
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

Просмотреть файл

@ -42,7 +42,7 @@ struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
return 1;
}
// All inputs of the Inverse op are float tensors; the index parameter is
// unnamed to avoid an unused-parameter warning.
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
@ -50,7 +50,7 @@ struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
return 1;
}
// The Inverse op always produces a float tensor; the index parameter is
// unnamed to avoid an unused-parameter warning.
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
};

Просмотреть файл

@ -48,7 +48,7 @@ struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
return 1;
}
// All inputs of the NegPos op are float tensors; the index parameter is
// unnamed to avoid an unused-parameter warning.
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
@ -56,7 +56,7 @@ struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
return 2;
}
// Both NegPos outputs are float tensors; the index parameter is unnamed to
// avoid an unused-parameter warning.
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
};

Просмотреть файл

@ -55,6 +55,6 @@ const char* CustomOpSegmentExtraction::GetName() const {
return "SegmentExtraction";
};
// SegmentExtraction takes int64 tensor input(s); the index parameter is
// unnamed to avoid an unused-parameter warning. (Also drops the stray
// trailing semicolon after the function body.)
ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}

Просмотреть файл

@ -30,7 +30,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
T* p_output = ort_.GetTensorMutableData<T>(v);
int64_t out_size = dim_out.Size();
memset(p_output, 0, out_size * sizeof(T));
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(T)));
// The implementation is naive. It could be parallelized and
// use SIMD instructions to be faster.

Просмотреть файл

@ -38,7 +38,7 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
// Allocates the kernel output tensor `output_idx` with shape `dim` and
// copies `data` into it element by element.
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
  OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
  int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
  // size_t index avoids a signed/unsigned comparison against data.size().
  for (size_t i = 0; i < data.size(); i++) {
    data_ptr[i] = data[i];
  }
}

Просмотреть файл

@ -5,6 +5,7 @@
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output) {
(void)context;
OrtTensorDimensions dimensions(ort, value);
size_t len = static_cast<size_t>(dimensions.Size());
size_t data_len;
@ -15,14 +16,16 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
output.resize(len);
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
if (i < len - 1)
result[offsets[i + (int64_t)1]] = '\0';
output[i] = result.data() + offsets[i];
if (i < static_cast<int64_t>(len) - 1)
result[offsets[static_cast<size_t>(i + (int64_t)1)]] = '\0';
output[static_cast<size_t>(i)] = result.data() + offsets[static_cast<size_t>(i)];
}
}
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output) {
(void)ort;
(void)context;
std::vector<const char*> temp(value.size());
for (size_t i = 0; i < value.size(); ++i) {
temp[i] = value[i].c_str();

Просмотреть файл

@ -14,7 +14,7 @@ inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
ss << "[";
for (int i = 0; i < t.size(); i++) {
for (size_t i = 0; i < t.size(); i++) {
if (i != 0) {
ss << ", ";
}
@ -31,7 +31,7 @@ inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
ss << "[";
for (int i = 0; i < t.size(); i++) {
for (size_t i = 0; i < t.size(); i++) {
if (i != 0) {
ss << ", ";
}

Просмотреть файл

@ -37,7 +37,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
std::vector<std::string> result;
std::vector<int64_t> result_dimension;
for (int i = 0; i < value.size(); i++) {
for (size_t i = 0; i < value.size(); i++) {
if (!mask[i]) {
continue;
}

Просмотреть файл

@ -11,7 +11,7 @@ class BroadcastIteratorRight {
public:
BroadcastIteratorRight(const std::vector<int64_t>& shape1,
const std::vector<int64_t>& shape2,
const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) {
if (shape2.size() > shape1.size())
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
shape2_.resize(shape1_.size());
@ -82,7 +82,7 @@ class BroadcastIteratorRight {
}
template <typename TCMP>
void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) {
if (pos != 0)
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
while (!end()) {

Просмотреть файл

@ -155,7 +155,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
std::vector<int64_t> shape_out{size - 1, max_col};
int64_t shape_out_size = shape_out[0] * shape_out[1];
std::vector<std::string> dense(max_col * (size - 1));
std::vector<std::string> dense(static_cast<size_t>(max_col * (size - 1)));
int64_t pos = 0;
int64_t j, pos_end;
for (int64_t i = 0; i < size - 1; ++i) {
@ -165,7 +165,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
dense[pos] = input[j];
dense[static_cast<size_t>(pos)] = input[static_cast<size_t>(j)];
}
pos = pos_end;
}

Просмотреть файл

@ -42,7 +42,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
re2::StringPiece piece(str_rewrite[0]);
@ -50,11 +50,11 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
// Do computation
if (global_replace_) {
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
}
} else {
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
re2::RE2::Replace(&(str_input[i]), reg, piece);
}
}

Просмотреть файл

@ -49,7 +49,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
std::vector<std::string_view> tokens;
std::vector<int64_t> begin_offsets;
std::vector<int64_t> end_offsets;
RegexSplitImpl(str_input[i], reg,
RegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
include_delimiter, keep_reg,
tokens, begin_offsets, end_offsets);
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());

Просмотреть файл

@ -29,7 +29,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, right, right_value);
// reuse left_value as output to save memory
for (int i = 0; i < left_value.size(); i++) {
for (size_t i = 0; i < left_value.size(); i++) {
left_value[i].append(right_value[i]);
}

Просмотреть файл

@ -47,7 +47,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
auto regex_flag = std::regex_constants::optimize | std::regex_constants::ECMAScript;
@ -58,11 +58,11 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
std::regex reg(str_pattern[0], regex_flag);
if (global_replace_) {
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0]);
}
} else {
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0], std::regex_constants::format_first_only);
}
}

Просмотреть файл

@ -54,7 +54,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
std::vector<std::string_view> tokens;
std::vector<int64_t> begin_offsets;
std::vector<int64_t> end_offsets;
ECMARegexSplitImpl(str_input[i], reg,
ECMARegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
include_delimiter, keep_reg,
tokens, begin_offsets, end_offsets);
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());

Просмотреть файл

@ -32,9 +32,9 @@ void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
std::vector<T>& end_offsets) {
size_t prev_pos = 0;
for (auto it = std::sregex_iterator(input.begin(), input.end(), pattern); it != std::sregex_iterator(); it++) {
size_t cur_pos = it->position();
size_t matched_length = it->length();
if (prev_pos != it->position()) {
int cur_pos = static_cast<int>(it->position());
int matched_length = static_cast<int>(it->length());
if (static_cast<decltype(it->position())>(prev_pos) != it->position()) {
tokens.emplace_back(input.c_str() + prev_pos, cur_pos - prev_pos);
begin_offsets.push_back(prev_pos);
end_offsets.push_back(cur_pos);

Просмотреть файл

@ -33,12 +33,12 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(Hash64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}
@ -92,12 +92,12 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(util::Fingerprint64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}

Просмотреть файл

@ -30,7 +30,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
if (X.size() != 1)
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
} else {
if (*axis < 0 || *axis >= dimensions.size())
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
}
@ -49,18 +49,18 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
std::vector<std::string> out(size);
std::vector<std::string> out(static_cast<size_t>(size));
if (dimensions.size() > 0) {
if (X.size() > 0) {
// Do computation
int64_t h = 1;
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
for (size_t i = static_cast<size_t>(*axis + 1); i < dimensions.size(); ++i) {
h *= dimensions[i];
}
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[*axis] - 1;
int64_t n_red = dimensions[static_cast<size_t>(*axis)] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {
@ -68,10 +68,10 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[index] << sep[0];
st << X[static_cast<size_t>(index)] << sep[0];
}
st << X[index];
out[pos] = st.str();
st << X[static_cast<size_t>(index)];
out[static_cast<size_t>(pos)] = st.str();
}
}
} else {

Просмотреть файл

@ -16,7 +16,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) {return static_cast<char>(ToLower(c));});
}

Просмотреть файл

@ -36,10 +36,10 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
if (delimiter.size() == 0) {
char word[2] = "a";
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[row];
const std::string& str = X[static_cast<size_t>(row)];
if (str.empty())
continue;
maxc = str.size() > maxc ? str.size() : maxc;
maxc = static_cast<int64_t>(static_cast<int64_t>(str.size()) > maxc ? str.size() : maxc);
for (auto it = str.begin(); it != str.end(); ++it) {
word[0] = *it;
words.push_back(word);
@ -51,7 +51,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
bool keep = !(*skip_empty);
std::size_t current, previous = 0;
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[row];
const std::string& str = X[static_cast<size_t>(row)];
if (str.empty())
continue;
previous = 0;

Просмотреть файл

@ -91,7 +91,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
std::vector<std::string_view> value_strs = SplitString(v, " ", true);
int64_t value;
for (int i = 0; i < value_strs.size(); i++) {
for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);

Просмотреть файл

@ -16,8 +16,8 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), ::toupper);
for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c){ return static_cast<char>(::toupper(c)); });
}
// Fills the output

Просмотреть файл

@ -28,7 +28,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
// only hit when the key is a scalar and the input is a vector
output_dim = input_dim;
} else {
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != vector_len_) {
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
}
@ -37,9 +37,9 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
}
std::vector<int64_t> key(vector_len_);
for (int i = 0; i < input_dim.Size(); i += vector_len_) {
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
//construct key
for (int j = 0; j < vector_len_; j++) {
for (size_t j = 0; j < vector_len_; j++) {
key[j] = ptr[j];
}
@ -94,7 +94,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
std::vector<std::string_view> value_strs = SplitString(v, " ", true);
int64_t value;
for (int i = 0; i < value_strs.size(); i++) {
for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) {
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);

Просмотреть файл

@ -10,8 +10,12 @@
#include <algorithm>
// Stores the tokenization option flags.
// NOTE(review): the initializer list is ordered (strip_accents_ before
// tokenize_chinese_chars_) presumably to match the member declaration order
// and avoid a -Wreorder warning — confirm against the class declaration.
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
    do_lower_case_(do_lower_case),
    strip_accents_(strip_accents),
    tokenize_chinese_chars_(tokenize_chinese_chars),
    tokenize_punctuation_(tokenize_punctuation),
    remove_control_chars_(remove_control_chars)
{}
std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
std::vector<ustring> result;

Просмотреть файл

@ -5,8 +5,8 @@
// Builds the token -> id map from the raw vocabulary text, one token per
// "\r\n"-separated line; a token's id is its line index.
BertTokenizerVocab::BertTokenizerVocab(std::string_view vocab) : raw_vocab_(vocab) {
  auto tokens = SplitString(raw_vocab_, "\r\n", true);
  for (size_t i = 0; i < tokens.size(); i++) {
    // Ids are 32-bit; the explicit cast silences the size_t narrowing warning.
    (vocab_)[tokens[i]] = static_cast<int32_t>(i);
  }
}
@ -40,11 +40,16 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
}
// Wordpiece tokenizer over a shared vocabulary.
// NOTE(review): initializers are ordered to match the (presumed) member
// declaration order so the textual order matches actual initialization
// order — confirm against the class declaration.
WordpieceTokenizer::WordpieceTokenizer(
    std::shared_ptr<BertTokenizerVocab> vocab,
    ustring unk_token,
    ustring suffix_indicator,
    int max_input_chars_per_word
) :
    max_input_chars_per_word_(max_input_chars_per_word),
    suffix_indicator_(std::move(suffix_indicator)),
    unk_token_(std::move(unk_token)),
    vocab_(std::move(vocab))
{
  // Resolve the unknown-token id once, after unk_token_ and vocab_ are set.
  unk_token_id_ = vocab_->FindTokenId(unk_token_);
}
@ -92,7 +97,7 @@ std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& toke
}
void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
if (token.size() > max_input_chars_per_word_) {
if (static_cast<int64_t>(token.size()) > max_input_chars_per_word_) {
tokenized_result.push_back(unk_token_);
return;
}
@ -127,13 +132,13 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
}
// Truncates `ids` in place to at most `max_len` elements.
// A non-positive max_len is treated as "no limit" and leaves the vector
// untouched; the cast makes the size comparison unsigned-safe.
void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int32_t max_len) {
  if ((max_len > 0) && (static_cast<size_t>(max_len) < ids.size())) {
    ids.resize(static_cast<size_t>(max_len));
  }
}
void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len) {
if (max_len < 0 || (ids1.size() + ids2.size() <= max_len)) {
if (max_len < 0 || (ids1.size() + ids2.size() <= static_cast<size_t>(max_len))) {
return;
}
@ -145,7 +150,7 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>
case TruncateStrategyType::LONGEST_FIRST:
case TruncateStrategyType::LONGEST_FROM_BACK:
if ((ids1_keep_len > half_max_len) && (ids2_keep_len > half_max_len)) {
if ((ids1_keep_len > static_cast<size_t>(half_max_len)) && (ids2_keep_len > static_cast<size_t>(half_max_len))) {
ids1_keep_len = static_cast<size_t>(max_len) - half_max_len;
ids2_keep_len = half_max_len;
} else if (ids2_keep_len > ids1_keep_len) {
@ -173,12 +178,24 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>
}
BertTokenizer::BertTokenizer(
const std::string& vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token,
ustring sep_token, ustring pad_token, ustring cls_token, ustring mask_token,
bool tokenize_chinese_chars, bool strip_accents, ustring suffix_indicator, int32_t max_len,
const std::string& truncation_strategy)
: do_basic_tokenize_(do_basic_tokenize), max_length_(max_len)
, truncate_(std::make_unique<TruncateStrategy>(truncation_strategy)) {
const std::string& vocab,
bool do_lower_case,
bool do_basic_tokenize,
ustring unk_token,
ustring sep_token,
ustring pad_token,
ustring cls_token,
ustring mask_token,
bool tokenize_chinese_chars,
bool strip_accents,
ustring suffix_indicator,
int32_t max_len,
const std::string& truncation_strategy
) :
max_length_(max_len),
do_basic_tokenize_(do_basic_tokenize),
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy))
{
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);

Просмотреть файл

@ -1,25 +1,36 @@
#include "bert_tokenizer_decoder.hpp"
BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, std::string unk_token, std::string sep_token, std::string pad_token,
std::string cls_token,std::string mask_token,std::string suffix_indicator) : raw_vocab_(vocab), unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
BertTokenizerDecoder::BertTokenizerDecoder(
std::string vocab,
std::string unk_token,
std::string sep_token,
std::string pad_token,
std::string cls_token,
std::string mask_token,
std::string suffix_indicator
) :
unk_token_(unk_token),
suffix_indicator_(suffix_indicator),
raw_vocab_(vocab)
{
auto tokens = SplitString(raw_vocab_, "\n", true);
vocab_.reserve(tokens.size());
for (int i = 0; i < tokens.size(); i++) {
for (size_t i = 0; i < tokens.size(); i++) {
auto& token = tokens[i];
if (token == unk_token) {
unk_token_id_ = i;
unk_token_id_ = static_cast<int32_t>(i);
}
if (token == sep_token) {
sep_token_id_ = i;
sep_token_id_ = static_cast<int32_t>(i);
}
if (token == pad_token) {
sep_token_id_ = i;
sep_token_id_ = static_cast<int32_t>(i);
}
if (token == cls_token) {
cls_token_id_ = i;
cls_token_id_ = static_cast<int32_t>(i);
}
if (token == mask_token) {
mask_token_id_ = i;
mask_token_id_ = static_cast<int32_t>(i);
}
if (token.rfind(suffix_indicator_, 0) == 0) {
@ -42,7 +53,7 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
}
// deal with unk ids
if (id >= vocab_.size() || id < 0) {
if (id < 0 || static_cast<size_t>(id) >= vocab_.size()) {
if (!result.empty()) {
result.push_back(' ');
}
@ -51,7 +62,7 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
}
// skip first substr
if (result.empty() && is_substr_[id]) {
if (result.empty() && is_substr_[static_cast<size_t>(id)]) {
continue;
}
@ -59,11 +70,11 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
// we needn't add a space at the beginning of the output
// we needn't add a space when the token is a substr (such as ##ing)
// we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
if (!(result.empty() || is_substr_[static_cast<size_t>(id)] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
result.push_back(' ');
}
result.append(vocab_[id]);
result.append(vocab_[static_cast<size_t>(id)]);
pre_token = id;
}
@ -76,8 +87,8 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
return true;
}
auto pre_char = ustring(vocab_[pre_token_id]).back();
auto cur_char = ustring(vocab_[new_token_id])[0];
auto pre_char = ustring(vocab_[static_cast<size_t>(pre_token_id)]).back();
auto cur_char = ustring(vocab_[static_cast<size_t>(new_token_id)])[0];
// normal punctuation
if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {

Просмотреть файл

@ -15,7 +15,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
if (model_ptr == nullptr) {
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
@ -24,7 +24,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
if (HasAttribute("max_sentence")) {
max_sentence = ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence");
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
}
}
@ -42,10 +42,10 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, input, input_data);
std::string& input_string = input_data[0];
int max_length = 2 * input_string.size() + 1;
int max_length = static_cast<int>(2 * input_string.size() + 1);
std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.get(), nullptr, nullptr, max_length, model_.get());
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
if (output_length < 0) {
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
}

Просмотреть файл

@ -167,8 +167,8 @@ class VocabData {
line = std::regex_replace(line, std::regex("\r"), "");
ustring line_32(line);
int id = static_cast<int>(vocab_map_.size());
if (auto it = vocab_map_.find(line); it != vocab_map_.end()) {
id = it->second;
if (auto nestedIt = vocab_map_.find(line); nestedIt != vocab_map_.end()) {
id = nestedIt->second;
} else {
vocab_map_[line] = id;
}
@ -237,7 +237,7 @@ class VocabData {
}
const std::string& IdToToken(int id) const {
if ((id < 0) || (id >= id2token_map_.size())) {
if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) {
ORT_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
}
return id2token_map_[id];
@ -500,7 +500,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
TokenWithRegularExp regcmp;
for (auto& seg_id : special_token_split_res) {
if (res.size() >= max_length) break;
if (static_cast<int64_t>(res.size()) >= max_length) break;
if (seg_id.second != -1) {
res.push_back(seg_id.second);
@ -512,7 +512,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
const char32_t* ptr = cur_input.c_str();
regcmp.Set(ptr);
while (res.size() < max_length) {
while (static_cast<int64_t>(res.size()) < max_length) {
auto [b, tok] = regcmp.GetNextToken();
if (!b) break;
@ -526,7 +526,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
bbpe_tokenizer_->bpe(byte_list_);
for (auto p : byte_list_) {
if (res.size() >= max_length) {
if (static_cast<int64_t>(res.size()) >= max_length) {
break;
}
@ -556,7 +556,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
max_length = std::max(max_length, res.size());
}
} else {
max_length = padding_length_;
max_length = static_cast<size_t>(padding_length_);
}
OrtTensorDimensions output_dim = input_dim;
@ -574,7 +574,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
idx++;
}
for (int i = res.size(); i < max_length; i++) {
for (size_t i = res.size(); i < max_length; i++) {
token[idx] = 0;
mask[idx] = 0;
idx++;
@ -594,14 +594,14 @@ size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
return 1;
}
// The BPE tokenizer's input is a string tensor; the index parameter is
// unnamed to avoid an unused-parameter warning.
ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
return 2;
}
// Both tokenizer outputs are int64 tensors; the index parameter is unnamed
// to avoid an unused-parameter warning.
ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}

Просмотреть файл

@ -12,9 +12,9 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes;
if (base64_decode(model_as_string, model_as_bytes)) {
model_proto.ParseFromArray(model_as_bytes.data(), model_as_bytes.size());
model_proto.ParseFromArray(model_as_bytes.data(), static_cast<int>(model_as_bytes.size()));
} else {
model_proto.ParseFromArray(model_as_string.c_str(), model_as_string.size());
model_proto.ParseFromArray(model_as_string.c_str(), static_cast<int>(model_as_string.size()));
}
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
if (!status.ok())
@ -46,6 +46,9 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
const OrtValue* ort_add_rev = ort_.KernelContext_GetInput(context, 5);
const bool* p_add_rev = ort_.GetTensorData<bool>(ort_add_rev);
(void)p_nbest_size;
(void)p_alpha;
// Verifications
_check_dimension_constant(ort_, ort_nbest_size, "nbest_size");
_check_dimension_constant(ort_, ort_alpha, "alpha");

Просмотреть файл

@ -24,7 +24,7 @@ KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtK
}
}
void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
void KernelWordpieceTokenizer_Split(const std::u32string& /*suffix_indicator*/,
const std::u32string& text,
std::vector<std::u32string>& words) {
ustring space(" ");
@ -81,7 +81,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
KernelWordpieceTokenizer_Split(suffix_indicator, *it, words);
for (auto itk = words.begin(); itk != words.end(); ++itk) {
if (itk->size() > max_input_chars_per_word) {
if (static_cast<int64_t>(itk->size()) > max_input_chars_per_word) {
indices.push_back(-1);
tokens.push_back(unk_token);
continue;
@ -107,7 +107,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
is_bad = true;
break;
}
indices.push_back(cur_substr);
indices.push_back(static_cast<int32_t>(cur_substr));
tokens.push_back(ustring(token));
start = end;
}
@ -153,13 +153,13 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
int64_t i;
for (i = 0; i < size_row_lengths[0]; ++i) {
ptr_row_lengths[i] = row_begins[i];
ptr_row_begins[i] = row_begins[i];
ptr_limit_values[i] = row_begins[i + 1];
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
ptr_row_begins[i] = row_begins[static_cast<size_t>(i)];
ptr_limit_values[i] = row_begins[static_cast<size_t>(i + 1)];
}
i = size_row_lengths[0];
ptr_row_lengths[i] = row_begins[i];
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
}
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {

Просмотреть файл

@ -69,12 +69,12 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
OrtStatus* status = nullptr;
#if defined(PYTHON_OP_SUPPORT)
if (status = RegisterPythonDomainAndOps(options, ortApi)){
if (status = RegisterPythonDomainAndOps(options, ortApi); status) {
return status;
}
#endif // PYTHON_OP_SUPPORT
if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain); status) {
return status;
}
@ -84,7 +84,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
size_t count = 0;
const OrtCustomOp* c_ops = FetchPyCustomOps(count);
while (c_ops != nullptr) {
if (status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
if (status = ortApi->CustomOpDomain_Add(domain, c_ops); status) {
return status;
} else {
pyop_nameset.emplace(c_ops->GetName(c_ops));
@ -114,7 +114,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
auto ops = fx();
while (*ops != nullptr) {
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
if (status = ortApi->CustomOpDomain_Add(domain, *ops)) {
if (status = ortApi->CustomOpDomain_Add(domain, *ops); status) {
return status;
}
}
@ -126,7 +126,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
while (e_ops != nullptr) {
if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
if (status = ortApi->CustomOpDomain_Add(domain, e_ops)) {
if (status = ortApi->CustomOpDomain_Add(domain, e_ops); status) {
return status;
}
e_ops = ExternalCustomOps::instance().GetNextOp(idx);