diff --git a/.gitignore b/.gitignore
index 130a530f..906adb07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,7 @@ out/
 .setuptools-cmake-build/
 onnxruntime-*-*-*/
 temp_*.onnx
+test/data/*.py
 
 # Compiled Dynamic libraries
 *.so
diff --git a/onnxruntime_extensions/_cuops.py b/onnxruntime_extensions/_cuops.py
index bf5d3f7d..1f733438 100644
--- a/onnxruntime_extensions/_cuops.py
+++ b/onnxruntime_extensions/_cuops.py
@@ -20,7 +20,7 @@ class CustomOp:
     def get_inputs(cls): return None
 
     @classmethod
-    def get_output(cls): return None
+    def get_outputs(cls): return None
 
     @classmethod
     def serialize_attr(cls, attrs):
diff --git a/operators/string_utils.cc b/operators/string_utils.cc
index 81c6d9c3..34b447d6 100644
--- a/operators/string_utils.cc
+++ b/operators/string_utils.cc
@@ -44,6 +44,53 @@ bool IsCJK(char32_t c) {
       || (c >= 0x2F800 && c <= 0x2FA1F);
 }
 
+// Generated by tools/generate_unicode_category_table.py
+bool IsSpace(char32_t c) {
+  if (c == 13||c == 32||c == 160||c == 8239||c == 8287||c == 12288) {
+    return true;
+  }
+
+  if ((c >= 9 && c <= 10)||(c >= 8192 && c <= 8202)) {
+    return true;
+  }
+
+  return false;
+}
+
+// Generated by tools/generate_unicode_category_table.py
+bool IsPunct(char32_t c) {
+  if (c == 161||c == 167||c == 171||c == 187||c == 191||c == 894||c == 903||c == 12336||c == 12349) {
+    return true;
+  }
+
+  if ((c >= 33 && c <= 47)||(c >= 58 && c <= 64)||(c >= 91 && c <= 96)||(c >= 123 && c <= 126)
+      ||(c >= 182 && c <= 183)||(c >= 8208 && c <= 8231)||(c >= 8240 && c <= 8259)
+      || (c >= 8261 && c <= 8273)||(c >= 8275 && c <= 8286)||(c >= 12289 && c <= 12291)
+      ||(c >= 12296 && c <= 12305)||(c >= 12308 && c <= 12319)) {
+    return true;
+  }
+
+  return false;
+}
+
+// Generated by tools/generate_unicode_category_table.py
+bool IsControl(char32_t c) {
+  if (c == 173||c == 907||c == 909||c == 930||c == 11930||c == 173790||c == 195102
+  ) {
+    return true;
+  }
+
+  if ((c >= 0 && c <= 8)||(c >= 11 && c <= 12)||(c >= 14 && c <= 31)||(c >= 128 && c <= 159)
+      ||(c >= 888 && c <= 889)||(c >= 896 && c <= 899)||(c >= 8203 && c <= 8207)
+      ||(c >= 8234 && c <= 8238)||(c >= 8288 && c <= 8302)||(c >= 12020 && c <= 12030)
+      ||(c >= 40957 && c <= 40958)||(c >= 64110 && c <= 64111)||(c >= 64218 && c <= 64254)
+      ||(c >= 177973 && c <= 177982)||(c >= 178206 && c <= 178207)||(c >= 183970 && c <= 183982)) {
+    return true;
+  }
+
+  return false;
+}
+
 bool IsAccent(char32_t c) {
   // only support part of accent
 
diff --git a/operators/string_utils.h b/operators/string_utils.h
index 9c82f343..4b25b91c 100644
--- a/operators/string_utils.h
+++ b/operators/string_utils.h
@@ -59,6 +59,12 @@ bool IsCJK(char32_t c);
 
 bool IsAccent(char32_t c);
 
+bool IsSpace(char32_t c);
+
+bool IsPunct(char32_t c);
+
+bool IsControl(char32_t c);
+
 char32_t StripAccent(char32_t c);
 
 uint64_t Hash64(const char* data, size_t n, uint64_t seed);
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 598d8189..f15a990d 100644
--- a/operators/tokenizer/basic_tokenizer.cc
+++ b/operators/tokenizer/basic_tokenizer.cc
@@ -54,22 +54,20 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
     }
 
     // 0x2019 unicode is not punctuation in some Linux platform,
-    // to be consistent, take it as punctatuation always.
-    if (tokenize_punctuation_ && (::iswpunct(c) || c == wint_t(0x2019))) {
+    // to be consistent, take it as punctuation.
+    if (tokenize_punctuation_ && IsPunct(c)) {
       push_current_token_and_clear();
       push_single_char_and_clear(c);
       continue;
     }
 
     // split by space
-    if (::iswspace(c)) {
+    if (IsSpace(c)) {
       push_current_token_and_clear();
       continue;
     }
 
-    // iscntrl will judge \t\f\n\r as control char
-    // but it has been filter by isspace(c)
-    if (remove_control_chars_ && ::iswcntrl(c)) {
+    if (remove_control_chars_ && IsControl(c)) {
       continue;
     }
 
diff --git a/test/data/test_bert_tokenizer1.onnx b/test/data/test_bert_tokenizer1.onnx
new file mode 100644
index 00000000..b3740160
Binary files /dev/null and b/test/data/test_bert_tokenizer1.onnx differ
diff --git a/test/data/test_segment_extraction.onnx b/test/data/test_segment_extraction.onnx
new file mode 100644
index 00000000..06fa17f6
Binary files /dev/null and b/test/data/test_segment_extraction.onnx differ
diff --git a/test/shared_test/test_ortops_tokenizer.cc b/test/shared_test/test_ortops_tokenizer.cc
new file mode 100644
index 00000000..d871a49b
--- /dev/null
+++ b/test/shared_test/test_ortops_tokenizer.cc
@@ -0,0 +1,83 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <filesystem>
+#include "gtest/gtest.h"
+#include "ocos.h"
+#include "test_kernel.hpp"
+
+TEST(utils, test_bert_tokenizer) {
+  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
+
+  std::vector<TestValue> inputs(1);
+  inputs[0].name = "text";
+  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  inputs[0].dims = {1};
+  inputs[0].values_string = {"We look forward to welcoming you to our stores. Whether you shop in a store or shop online, our Specialists can help you buy the products you love."};
+
+  std::vector<TestValue> outputs(3);
+  outputs[0].name = "input_ids";
+  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[0].dims = {34};
+  outputs[0].values_int64 = {101, 1284, 1440, 1977, 1106, 20028, 1128, 1106, 1412, 4822, 119, 13197, 1128, 4130, 1107, 170, 2984, 1137, 4130, 3294, 117, 1412, 25607, 1116, 1169, 1494, 1128, 4417, 1103, 2982, 1128, 1567, 119, 102};
+
+  outputs[1].name = "token_type_ids";
+  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[1].dims = {34};
+  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  outputs[2].name = "attention_mask";
+  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[2].dims = {34};
+  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+
+  std::filesystem::path model_path = __FILE__;
+  model_path = model_path.parent_path();
+  model_path /= "..";
+  model_path /= "data";
+  model_path /= "test_bert_tokenizer1.onnx";
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
+
+
+  inputs[0].name = "text";
+  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  inputs[0].dims = {1};
+  inputs[0].values_string = {"本想好好的伤感 想放任 但是没泪痕"};
+
+  outputs[0].name = "input_ids";
+  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[0].dims = {17};
+  outputs[0].values_int64 = {101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102};
+
+  outputs[1].name = "token_type_ids";
+  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[1].dims = {17};
+  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  outputs[2].name = "attention_mask";
+  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[2].dims = {17};
+  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
+
+  inputs[0].name = "text";
+  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  inputs[0].dims = {1};
+  inputs[0].values_string = {"ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~"};
+
+  outputs[0].name = "input_ids";
+  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[0].dims = {71};
+  outputs[0].values_int64 = {101, 13807, 11189, 8101, 27073, 27073, 12738, 11607, 2346, 2346, 2346, 2346, 2346, 2591, 2591, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 106, 108, 109, 110, 111, 113, 2545, 137, 17599, 7301, 4964, 119, 3254, 114, 115, 116, 117, 118, 119, 120, 131, 132, 133, 134, 135, 136, 137, 164, 165, 166, 167, 168, 169, 196, 197, 198, 199, 102};
+
+  outputs[1].name = "token_type_ids";
+  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[1].dims = {71};
+  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  outputs[2].name = "attention_mask";
+  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[2].dims = {71};
+  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
+}
\ No newline at end of file
diff --git a/test/test_bert_tokenizer.py b/test/test_bert_tokenizer.py
index 8747ed2b..efc73459 100644
--- a/test/test_bert_tokenizer.py
+++ b/test/test_bert_tokenizer.py
@@ -4,24 +4,27 @@ import numpy as np
 import transformers
 from onnxruntime_extensions import PyOrtFunction, BertTokenizer
 
-bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
-bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
 
 
 def _get_test_data_file(*sub_dirs):
     test_dir = Path(__file__).parent
     return str(test_dir.joinpath(*sub_dirs))
 
+bert_cased_tokenizer = transformers.BertTokenizer(_get_test_data_file('data', 'bert_basic_cased_vocab.txt'), False,
+                                                  strip_accents=True)
+
+
 def _run_basic_case(input, vocab_path):
-    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
+    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
     result = t2stc([input])
     expect_result = bert_cased_tokenizer.encode_plus(input)
     np.testing.assert_array_equal(result[0], expect_result['input_ids'])
     np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
     np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
 
+
 def _run_combined_case(input, vocab_path):
-    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
+    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
     result = t2stc(input)
     expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
     np.testing.assert_array_equal(result[0], expect_result['input_ids'])
@@ -34,17 +37,19 @@ class TestBertTokenizer(unittest.TestCase):
     def test_text_to_case1(self):
         _run_basic_case(input="Input 'text' must not be empty.",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
+        _run_basic_case(
+            input="ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~",
+            vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="网易云音乐",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
+        _run_basic_case(input="本想好好的伤感 想放任 但是没泪痕", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="网 易 云 音 乐",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="cat is playing toys",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="cat isnot playing toyssss",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
-        _run_basic_case(input="cat isnot playing toyssss",
-                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
-        _run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"], vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
-
+        _run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"],
+                           vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
 
 if __name__ == "__main__":
diff --git a/tools/generate_unicode_category_table.py b/tools/generate_unicode_category_table.py
new file mode 100644
index 00000000..2eaa1658
--- /dev/null
+++ b/tools/generate_unicode_category_table.py
@@ -0,0 +1,144 @@
+import unicodedata
+
+
+def _is_whitespace(char):
+    """Checks whether `char` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == " " or char == "\t" or char == "\n" or char == "\r":
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `char` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == "\t" or char == "\n" or char == "\r":
+        return False
+    cat = unicodedata.category(char)
+    if cat.startswith("C"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `char` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
+
+
+def find_expect_char_in_range(judge_fun, start, end):
+    result = []
+    for c in range(start, end):
+        if judge_fun(chr(c)):
+            result.append(c)
+    return result
+
+
+def find_ranges(nums):
+    nums = sorted(set(nums))
+    gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
+    edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
+    return list(zip(edges, edges))
+
+
+def find_expect_category(category_func):
+    expect_category_set = []
+
+    # ASCII
+    expect_category_set += find_expect_char_in_range(category_func, 0, 0x7F)
+
+    # C1 Controls and Latin-1 Supplement
+    expect_category_set += find_expect_char_in_range(category_func, 0x80, 0xFF)
+
+    # Latin Extended-A
+    expect_category_set += find_expect_char_in_range(category_func, 0x100, 0x17F)
+
+    # Latin Extended-B
+    expect_category_set += find_expect_char_in_range(category_func, 0x180, 0x24F)
+
+    # IPA Extensions
+    expect_category_set += find_expect_char_in_range(category_func, 0x250, 0x2AF)
+
+    # Spacing Modifier Letters
+    expect_category_set += find_expect_char_in_range(category_func, 0x2B0, 0x2FF)
+
+    # Combining Diacritical Marks
+    expect_category_set += find_expect_char_in_range(category_func, 0x300, 0x36F)
+
+    # Greek/Coptic
+    expect_category_set += find_expect_char_in_range(category_func, 0x370, 0x3FF)
+
+    # Cyrillic and Cyrillic Supplement
+    expect_category_set += find_expect_char_in_range(category_func, 0x400, 0x52F)
+
+    # General Punctuation
+    expect_category_set += find_expect_char_in_range(category_func, 0x2000, 0x206F)
+
+    # CJK Radicals Supplement
+    expect_category_set += find_expect_char_in_range(category_func, 0x2E80, 0x2EFF)
+
+    # CJK Symbols and Punctuation
+    expect_category_set += find_expect_char_in_range(category_func, 0x3000, 0x303F)
+
+    # CJK
+    expect_category_set += find_expect_char_in_range(category_func, 0x4E00, 0x9FFF)
+    expect_category_set += find_expect_char_in_range(category_func, 0x3400, 0x4DBF)
+    expect_category_set += find_expect_char_in_range(category_func, 0x20000, 0x2A6DF)
+    expect_category_set += find_expect_char_in_range(category_func, 0x2A700, 0x2B73F)
+    expect_category_set += find_expect_char_in_range(category_func, 0x2B740, 0x2CEAF)
+    expect_category_set += find_expect_char_in_range(category_func, 0xF900, 0xFAFF)
+    expect_category_set += find_expect_char_in_range(category_func, 0x2F800, 0x2FA1F)
+
+    return find_ranges(expect_category_set)
+
+def print_range(ranges):
+    single_set = []
+    pair_set = []
+    for r in ranges:
+        start, end = r
+        if start == end:
+            single_set.append(start)
+        else:
+            pair_set.append(r)
+
+    output = "if ("
+    for i in range(len(single_set)):
+        if i != 0:
+            output += "||"
+        output += f"c == {single_set[i]}"
+    output += ") {\n return true;\n}\n\n"
+
+    output += "if ("
+    for i in range(len(pair_set)):
+        if i != 0:
+            output += "||"
+        start, end = pair_set[i]
+        output += f"(c >= {start} && c <= {end})"
+    output += ") {\n return true;\n}\n\nreturn false;\n"
+    print(output)
+
+
+print("\nis_whitespace:")
+print_range(find_expect_category(_is_whitespace))
+
+print("\nis_punctuation:")
+print_range(find_expect_category(_is_punctuation))
+
+print("\nis_control:")
+print_range(find_expect_category(_is_control))
+
+
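A quick way to sanity-check the generated tables is to re-run the same predicates the generator uses and compare them with the emitted C++ ranges. The sketch below is not part of the patch; helper names such as is_punct_generated are illustrative. It transcribes the IsPunct() table from operators/string_utils.cc and checks it against the rule in _is_punctuation() over the General Punctuation block; the outcome depends on the unicodedata version of the running Python. The same pattern applies to IsSpace()/IsControl() versus _is_whitespace()/_is_control().

import unicodedata

# Transcription of the generated IsPunct() table from operators/string_utils.cc.
PUNCT_SINGLES = {161, 167, 171, 187, 191, 894, 903, 12336, 12349}
PUNCT_RANGES = [(33, 47), (58, 64), (91, 96), (123, 126), (182, 183),
                (8208, 8231), (8240, 8259), (8261, 8273), (8275, 8286),
                (12289, 12291), (12296, 12305), (12308, 12319)]


def is_punct_generated(c):
    # Mirrors the C++ IsPunct(): single code points first, then ranges.
    return c in PUNCT_SINGLES or any(lo <= c <= hi for lo, hi in PUNCT_RANGES)


def is_punct_reference(c):
    # Same rule as _is_punctuation() in tools/generate_unicode_category_table.py:
    # non-letter/number ASCII blocks plus Unicode category P*.
    if 33 <= c <= 47 or 58 <= c <= 64 or 91 <= c <= 96 or 123 <= c <= 126:
        return True
    return unicodedata.category(chr(c)).startswith("P")


# U+2019 motivated the old iswpunct() special case; the generated table covers it directly.
assert is_punct_generated(0x2019) and is_punct_reference(0x2019)

# Spot-check one block the generator scans (General Punctuation, U+2000..U+206F).
for cp in range(0x2000, 0x2070):
    assert is_punct_generated(cp) == is_punct_reference(cp), hex(cp)
print("IsPunct agrees with the reference predicate over U+2000..U+206F")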