Disable c++ exceptions in onnxruntime-extensions. (#143)

* Disable c++ exceptions in onnxruntime-extensions.

* Remove cxx flags for extensions.

* Remove redundant lines.

Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Zuwei Zhao 2021-09-08 19:21:40 -05:00 committed by GitHub
Parent 8649d98839
Commit 6d7a865913
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 38 additions and 36 deletions

View file

@ -1,12 +1,16 @@
# If the operator needs the cpp exceptions supports, write down their names
if (OCOS_ENABLE_GPT2_TOKENIZER)
# gpt2 tokenizer depends on nlohmann_json in onnxruntime, which is old and cannot disable exceptions.
# could remove this limit when the nlohmann_json is updated in onnxruntime.
message(FATAL_ERROR "GPT2_TOKENIZER operator needs c++ exceptions support")
endif()
if (OCOS_ENABLE_WORDPIECE_TOKENIZER)
# wordpiece tokenizer depends on nlohmann_json in onnxruntime, which is old and cannot disable exceptions.
# could remove this limit when the nlohmann_json is updated in onnxruntime.
message(FATAL_ERROR "WORDPIECE_TOKENIZER operator needs c++ exceptions support")
endif()
if (OCOS_ENABLE_BLINGFIRE)
message(FATAL_ERROR "BLINGFIRE operator needs c++ exceptions support")
message(STATUS "BLINGFIRE operator needs c++ exceptions support, enable exceptions by default!")
endif()
if (OCOS_ENABLE_SPM_TOKENIZER)
message(FATAL_ERROR "SPM_TOKENIZER operator needs c++ exceptions support")

View file

@ -15,13 +15,13 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
OrtTensorDimensions dim_data(ort_, data);
OrtTensorDimensions dim_seg(ort_, segment_ids);
if (dim_data.size() == 0 || dim_seg.size() == 0)
throw std::runtime_error("Both inputs cannot be empty.");
ORT_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_ARGUMENT);
if (dim_seg.size() != 1)
throw std::runtime_error("segment_ids must a single tensor");
ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_ARGUMENT);
if (dim_data[0] != dim_seg[0])
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data.GetDims(),
" segment_ids shape: ", dim_seg.GetDims()));
" segment_ids shape: ", dim_seg.GetDims()), ORT_INVALID_ARGUMENT);
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
OrtTensorDimensions dim_out = dim_data;
@ -42,9 +42,9 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
const int64_t* p_seg = p_segment_ids;
for (; begin != end; ++p_seg) {
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
throw std::runtime_error(MakeString("segment_ids must be increasing but found ",
ORT_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
*(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), "."));
std::distance(p_segment_ids, p_seg), "."), ORT_INVALID_ARGUMENT);
p_out = p_output + *p_seg * in_stride;
p_out_end = p_out + in_stride;
for (; p_out != p_out_end; ++p_out, ++begin)
@ -86,6 +86,6 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error("Operator SegmentSum has 2 inputs.");
ORT_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
}
};

View file

@ -27,15 +27,15 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."));
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."));
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
throw std::runtime_error("pattern (second input) cannot be empty.");
ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
// Setup output
OrtTensorDimensions dimensions(ort_, input);

View file

@ -24,15 +24,15 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
str_pattern.size(), " values."));
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
if (str_keep_pattern.size() > 1)
throw std::runtime_error(MakeString(
ORT_CXX_API_THROW(MakeString(
"Third input must contain only one element. It has ",
str_keep_pattern.size(), " values."));
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
throw std::runtime_error("Splitting pattern cannot be empty.");
ORT_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
@ -106,7 +106,7 @@ ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(siz
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."));
ORT_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
}
};

View file

@ -29,7 +29,7 @@ class SpecialTokenMap {
auto it = token_map_.find(p_str);
if (it != token_map_.end()) {
if (it->second != p_id) {
throw std::runtime_error("Duplicate special tokens");
ORT_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
}
} else {
token_map_[p_str] = p_id;
@ -84,7 +84,7 @@ class SpecialTokenMap {
SpecialTokenInfo(ustring p_str, int p_id)
: str(std::move(p_str)), id(p_id) {
if (str.empty()) {
throw std::runtime_error("Empty special token.");
ORT_CXX_API_THROW("Empty special token.", ORT_INVALID_ARGUMENT);
}
}
};
@ -147,7 +147,7 @@ class VocabData {
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
throw std::runtime_error("Cannot know how to parse line: " + line);
ORT_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);
@ -231,14 +231,14 @@ class VocabData {
int TokenToID(const std::string& input) const {
auto it = vocab_map_.find(input);
if (it == vocab_map_.end()) {
throw std::runtime_error("Token not found: " + input);
ORT_CXX_API_THROW("Token not found: " + input, ORT_INVALID_ARGUMENT);
}
return it->second;
}
const std::string& IdToToken(int id) const {
if ((id < 0) || (id >= id2token_map_.size())) {
throw std::runtime_error("Invalid ID: " + std::to_string(id));
ORT_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
}
return id2token_map_[id];
}
@ -247,7 +247,7 @@ class VocabData {
int GetVocabIndex(const std::string& str) {
auto it = vocab_map_.find(str);
if (it == vocab_map_.end()) {
throw std::runtime_error("Cannot find word in vocabulary: " + str);
ORT_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
}
return it->second;
}
@ -467,12 +467,12 @@ KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
if (vocab.empty()) {
throw std::runtime_error("vocabulary shouldn't be empty.");
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
if (merges.empty()) {
throw std::runtime_error("merges shouldn't be empty.");
ORT_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
@ -480,7 +480,7 @@ KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info)
}
if (padding_length_ != -1 && padding_length_ <= 0) {
throw std::runtime_error("padding_length should be more than 0 or equal -1");
ORT_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
}
std::stringstream vocabu_stream(vocab);

View file

@ -200,4 +200,4 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
default:
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};
};

View file

@ -4,19 +4,17 @@ import sys
OPMAP_TO_CMAKE_FLAGS = {'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE',
'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
'WordpieceTokenizer': 'OCOS_ENABLE_WORDPIECE_TOKENIZER',
# Currently use one option for all string operators because their binary sizes are not large.
# Would probably split to more options like tokenizers in the future.
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
'StringConcat': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
'StringEqual': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucket': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucketFast': 'OCOS_ENABLE_TF_STRING',
'StringJoin': 'OCOS_ENABLE_TF_STRING',
'StringLength': 'OCOS_ENABLE_TF_STRING',
'StringLower': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
'StringSplit': 'OCOS_ENABLE_TF_STRING',
'StringToVector': 'OCOS_ENABLE_TF_STRING',
'StringUpper': 'OCOS_ENABLE_TF_STRING',