Add native test for bert tokenizer (#173)
* add native test for bert tokenizer
* add python test
* fix unicode category

Co-authored-by: Ze Tao <zetao@microsoft.com>
Parent: 70aa18e14e
Commit: 448518534c
@@ -35,6 +35,7 @@ out/
.setuptools-cmake-build/
onnxruntime-*-*-*/
temp_*.onnx
test/data/*.py

# Compiled Dynamic libraries
*.so
@@ -20,7 +20,7 @@ class CustomOp:
     def get_inputs(cls): return None
 
     @classmethod
-    def get_output(cls): return None
+    def get_outputs(cls): return None
 
     @classmethod
     def serialize_attr(cls, attrs):
@@ -44,6 +44,53 @@ bool IsCJK(char32_t c) {
         || (c >= 0x2F800 && c <= 0x2FA1F);
}

// Generated by tools/generate_unicode_category_table.py
bool IsSpace(char32_t c) {
  if (c == 13||c == 32||c == 160||c == 8239||c == 8287||c == 12288) {
    return true;
  }

  if ((c >= 9 && c <= 10)||(c >= 8192 && c <= 8202)) {
    return true;
  }

  return false;
}

// Generated by tools/generate_unicode_category_table.py
bool IsPunct(char32_t c) {
  if (c == 161||c == 167||c == 171||c == 187||c == 191||c == 894||c == 903||c == 12336||c == 12349) {
    return true;
  }

  if ((c >= 33 && c <= 47)||(c >= 58 && c <= 64)||(c >= 91 && c <= 96)||(c >= 123 && c <= 126)
      ||(c >= 182 && c <= 183)||(c >= 8208 && c <= 8231)||(c >= 8240 && c <= 8259)
      ||(c >= 8261 && c <= 8273)||(c >= 8275 && c <= 8286)||(c >= 12289 && c <= 12291)
      ||(c >= 12296 && c <= 12305)||(c >= 12308 && c <= 12319)) {
    return true;
  }

  return false;
}

// Generated by tools/generate_unicode_category_table.py
bool IsControl(char32_t c) {
  if (c == 173||c == 907||c == 909||c == 930||c == 11930||c == 173790||c == 195102) {
    return true;
  }

  if ((c >= 0 && c <= 8)||(c >= 11 && c <= 12)||(c >= 14 && c <= 31)||(c >= 128 && c <= 159)
      ||(c >= 888 && c <= 889)||(c >= 896 && c <= 899)||(c >= 8203 && c <= 8207)
      ||(c >= 8234 && c <= 8238)||(c >= 8288 && c <= 8302)||(c >= 12020 && c <= 12030)
      ||(c >= 40957 && c <= 40958)||(c >= 64110 && c <= 64111)||(c >= 64218 && c <= 64254)
      ||(c >= 177973 && c <= 177982)||(c >= 178206 && c <= 178207)||(c >= 183970 && c <= 183982)) {
    return true;
  }

  return false;
}

bool IsAccent(char32_t c)
{
  // only support part of accent
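The generated tables replace locale-dependent ::iswspace / ::iswpunct / ::iswcntrl checks with fixed, Unicode-derived ranges. As a rough cross-check (illustration only, not part of the commit; is_space_py is a made-up helper), the same classification can be reproduced from Python's unicodedata, which is what the generator script relies on:

import unicodedata

def is_space_py(cp):
    # mirrors the intent of the generated IsSpace: \t, \n, \r, space, plus Unicode "Zs"
    ch = chr(cp)
    return ch in (" ", "\t", "\n", "\r") or unicodedata.category(ch) == "Zs"

# a few code points that appear in the generated C++ table above
for cp in (13, 32, 160, 8192, 8202, 8239, 8287, 12288):
    print(hex(cp), unicodedata.category(chr(cp)), is_space_py(cp))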
@@ -59,6 +59,12 @@ bool IsCJK(char32_t c);

bool IsAccent(char32_t c);

bool IsSpace(char32_t c);

bool IsPunct(char32_t c);

bool IsControl(char32_t c);

char32_t StripAccent(char32_t c);

uint64_t Hash64(const char* data, size_t n, uint64_t seed);
@@ -54,22 +54,20 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
     }
 
     // 0x2019 unicode is not punctuation in some Linux platform,
-    // to be consistent, take it as punctatuation always.
-    if (tokenize_punctuation_ && (::iswpunct(c) || c == wint_t(0x2019))) {
+    // to be consistent, take it as punctuation.
+    if (tokenize_punctuation_ && IsPunct(c)) {
       push_current_token_and_clear();
       push_single_char_and_clear(c);
       continue;
     }
 
     // split by space
-    if (::iswspace(c)) {
+    if (IsSpace(c)) {
       push_current_token_and_clear();
       continue;
     }
 
-    // iscntrl will judge \t\f\n\r as control char
-    // but it has been filter by isspace(c)
-    if (remove_control_chars_ && ::iswcntrl(c)) {
+    if (remove_control_chars_ && IsControl(c)) {
       continue;
     }
 
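Switching from ::iswpunct to the generated IsPunct also makes the U+2019 workaround unnecessary: the right single quotation mark is Unicode punctuation regardless of the platform locale. A quick illustration (not part of the commit):

import unicodedata

c = 0x2019  # RIGHT SINGLE QUOTATION MARK
print(unicodedata.category(chr(c)))  # 'Pf', a punctuation category
print(8208 <= c <= 8231)             # True: covered by a range in the generated IsPunct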
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include "gtest/gtest.h"
#include "ocos.h"
#include "test_kernel.hpp"

TEST(utils, test_bert_tokenizer) {
  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> inputs(1);
  inputs[0].name = "text";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1};
  inputs[0].values_string = {"We look forward to welcoming you to our stores. Whether you shop in a store or shop online, our Specialists can help you buy the products you love."};

  std::vector<TestValue> outputs(3);
  outputs[0].name = "input_ids";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[0].dims = {34};
  outputs[0].values_int64 = {101, 1284, 1440, 1977, 1106, 20028, 1128, 1106, 1412, 4822, 119, 13197, 1128, 4130, 1107, 170, 2984, 1137, 4130, 3294, 117, 1412, 25607, 1116, 1169, 1494, 1128, 4417, 1103, 2982, 1128, 1567, 119, 102};

  outputs[1].name = "token_type_ids";
  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[1].dims = {34};
  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

  outputs[2].name = "attention_mask";
  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[2].dims = {34};
  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

  std::filesystem::path model_path = __FILE__;
  model_path = model_path.parent_path();
  model_path /= "..";
  model_path /= "data";
  model_path /= "test_bert_tokenizer1.onnx";
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

  inputs[0].name = "text";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1};
  inputs[0].values_string = {"本想好好的伤感 想放任 但是没泪痕"};

  outputs[0].name = "input_ids";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[0].dims = {17};
  outputs[0].values_int64 = {101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102};

  outputs[1].name = "token_type_ids";
  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[1].dims = {17};
  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

  outputs[2].name = "attention_mask";
  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[2].dims = {17};
  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

  inputs[0].name = "text";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1};
  inputs[0].values_string = {"ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~"};

  outputs[0].name = "input_ids";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[0].dims = {71};
  outputs[0].values_int64 = {101, 13807, 11189, 8101, 27073, 27073, 12738, 11607, 2346, 2346, 2346, 2346, 2346, 2591, 2591, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 106, 108, 109, 110, 111, 113, 2545, 137, 17599, 7301, 4964, 119, 3254, 114, 115, 116, 117, 118, 119, 120, 131, 132, 133, 134, 135, 136, 137, 164, 165, 166, 167, 168, 169, 196, 197, 198, 199, 102};

  outputs[1].name = "token_type_ids";
  outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[1].dims = {71};
  outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

  outputs[2].name = "attention_mask";
  outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[2].dims = {71};
  outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
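The expected tensors follow the standard BERT encoding layout: ids wrapped in [CLS] (101) and [SEP] (102), all-zero token_type_ids and all-one attention_mask for a single sequence. A hedged sketch (not part of the commit) of how such reference values can be regenerated with Hugging Face transformers; the vocab path is an assumption based on the bert_basic_cased_vocab.txt used by the Python test below, and the ids only match if that exact vocab is used:

import transformers

# Assumed vocab location; the Python test below loads data/bert_basic_cased_vocab.txt
tok = transformers.BertTokenizer("test/data/bert_basic_cased_vocab.txt",
                                 do_lower_case=False, strip_accents=True)
enc = tok.encode_plus("We look forward to welcoming you to our stores. "
                      "Whether you shop in a store or shop online, our Specialists "
                      "can help you buy the products you love.")
print(enc["input_ids"])       # expected to start with 101 ([CLS]) and end with 102 ([SEP])
print(enc["token_type_ids"])  # all zeros for a single segment
print(enc["attention_mask"])  # all ones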
@@ -4,24 +4,27 @@ import numpy as np
 import transformers
 from onnxruntime_extensions import PyOrtFunction, BertTokenizer
 
-bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
-bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
 
 def _get_test_data_file(*sub_dirs):
     test_dir = Path(__file__).parent
     return str(test_dir.joinpath(*sub_dirs))
 
 
+bert_cased_tokenizer = transformers.BertTokenizer(_get_test_data_file('data', 'bert_basic_cased_vocab.txt'), False,
+                                                  strip_accents=True)
 
 
 def _run_basic_case(input, vocab_path):
-    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
+    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
     result = t2stc([input])
     expect_result = bert_cased_tokenizer.encode_plus(input)
     np.testing.assert_array_equal(result[0], expect_result['input_ids'])
     np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
     np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
 
 
 def _run_combined_case(input, vocab_path):
-    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
+    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
     result = t2stc(input)
     expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
     np.testing.assert_array_equal(result[0], expect_result['input_ids'])
@@ -34,17 +37,19 @@ class TestBertTokenizer(unittest.TestCase):
     def test_text_to_case1(self):
         _run_basic_case(input="Input 'text' must not be empty.",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(
             input="ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~",
             vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="网易云音乐", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="本想好好的伤感 想放任 但是没泪痕", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="网 易 云 音 乐",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="cat is playing toys",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="cat isnot playing toyssss",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
         _run_basic_case(input="cat isnot playing toyssss",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
-        _run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"], vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
+        _run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"],
+                           vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
 
 
 if __name__ == "__main__":
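Both helpers now pass strip_accents=1 to the custom op and build the reference tokenizer with strip_accents=True, so accented inputs such as the "ÀÁÂ…" case are folded before WordPiece matching. A minimal sketch of what accent stripping means here (an illustrative helper, not code from the commit): NFD-decompose and drop combining marks (category Mn).

import unicodedata

def strip_accents(text):
    # decompose, then drop combining marks (Unicode category "Mn")
    return "".join(ch for ch in unicodedata.normalize("NFD", text)
                   if unicodedata.category(ch) != "Mn")

print(strip_accents("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ"))  # -> AAAAAACEEEEIIINOOOOOUU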
@@ -0,0 +1,144 @@
import unicodedata


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


def find_expect_char_in_range(judge_fun, start, end):
    result = []
    for c in range(start, end):
        if judge_fun(chr(c)):
            result.append(c)
    return result


def find_ranges(nums):
    nums = sorted(set(nums))
    gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
    edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
    return list(zip(edges, edges))


def find_expect_category(category_func):
    expect_category_set = []

    # ASCII
    expect_category_set += find_expect_char_in_range(category_func, 0, 0x7F)

    # C1 Controls and Latin-1 Supplement
    expect_category_set += find_expect_char_in_range(category_func, 0x80, 0xFF)

    # Latin Extended-A
    expect_category_set += find_expect_char_in_range(category_func, 0x100, 0x17F)

    # Latin Extended-B
    expect_category_set += find_expect_char_in_range(category_func, 0x180, 0x24F)

    # IPA Extensions
    expect_category_set += find_expect_char_in_range(category_func, 0x250, 0x2AF)

    # Spacing Modifier Letters
    expect_category_set += find_expect_char_in_range(category_func, 0x2B0, 0x2FF)

    # Combining Diacritical Marks
    expect_category_set += find_expect_char_in_range(category_func, 0x300, 0x36F)

    # Greek/Coptic
    expect_category_set += find_expect_char_in_range(category_func, 0x370, 0x3FF)

    # Cyrillic and Cyrillic Supplement
    expect_category_set += find_expect_char_in_range(category_func, 0x400, 0x52F)

    # General Punctuation
    expect_category_set += find_expect_char_in_range(category_func, 0x2000, 0x206F)

    # CJK Radicals Supplement
    expect_category_set += find_expect_char_in_range(category_func, 0x2E80, 0x2EFF)

    # CJK Symbols and Punctuation
    expect_category_set += find_expect_char_in_range(category_func, 0x3000, 0x303F)

    # CJK
    expect_category_set += find_expect_char_in_range(category_func, 0x4E00, 0x9FFF)
    expect_category_set += find_expect_char_in_range(category_func, 0x3400, 0x4DBF)
    expect_category_set += find_expect_char_in_range(category_func, 0x20000, 0x2A6DF)
    expect_category_set += find_expect_char_in_range(category_func, 0x2A700, 0x2B73F)
    expect_category_set += find_expect_char_in_range(category_func, 0x2B740, 0x2CEAF)
    expect_category_set += find_expect_char_in_range(category_func, 0xF900, 0xFAFF)
    expect_category_set += find_expect_char_in_range(category_func, 0x2F800, 0x2FA1F)

    return find_ranges(expect_category_set)


def print_range(ranges):
    single_set = []
    pair_set = []
    for r in ranges:
        start, end = r
        if start == end:
            single_set.append(start)
        else:
            pair_set.append(r)

    output = "if ("
    for i in range(len(single_set)):
        if i != 0:
            output += "||"
        output += f"c == {single_set[i]}"
    output += ") {\n return true;\n}\n\n"

    output += "if ("
    for i in range(len(pair_set)):
        if i != 0:
            output += "||"
        start, end = pair_set[i]
        output += f"(c >= {start} && c <= {end})"
    output += ") {\n return true;\n}\n\nreturn false;\n"
    print(output)


print("\nis_whitespace:")
print_range(find_expect_category(_is_whitespace))

print("\nis_punctuation:")
print_range(find_expect_category(_is_punctuation))

print("\nis_control:")
print_range(find_expect_category(_is_control))
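This script, presumably the tools/generate_unicode_category_table.py referenced by the comments in the C++ changes above, collects the matching code points per Unicode block and compresses them into the ranges emitted as C++ if blocks. A small illustration (not part of the commit) of how find_ranges collapses code points; the helper is repeated so the snippet runs standalone:

def find_ranges(nums):
    # same helper as in the script above, repeated so this example is self-contained
    nums = sorted(set(nums))
    gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
    edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
    return list(zip(edges, edges))

# 9 and 10 merge into one range; isolated values become single-point ranges,
# which print_range emits as "c == ..." comparisons.
print(find_ranges([9, 10, 13, 32, 160]))  # [(9, 10), (13, 13), (32, 32), (160, 160)]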