Add native test for bert tokenizer (#173)

* add native test for bert tokenizer

* add python test

* fix unicode category

Co-authored-by: Ze Tao <zetao@microsoft.com>
This commit is contained in:
Mojimi 2021-10-20 02:09:38 +08:00 коммит произвёл GitHub
Родитель 70aa18e14e
Коммит 448518534c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 299 добавлений и 15 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -35,6 +35,7 @@ out/
.setuptools-cmake-build/
onnxruntime-*-*-*/
temp_*.onnx
test/data/*.py
# Compiled Dynamic libraries
*.so

Просмотреть файл

@ -20,7 +20,7 @@ class CustomOp:
def get_inputs(cls): return None
@classmethod
def get_output(cls): return None
def get_outputs(cls): return None
@classmethod
def serialize_attr(cls, attrs):

Просмотреть файл

@ -44,6 +44,53 @@ bool IsCJK(char32_t c) {
|| (c >= 0x2F800 && c <= 0x2FA1F);
}
// Generated by tools/generate_unicode_category_table.py
bool IsSpace(char32_t c) {
if (c == 13||c == 32||c == 160||c == 8239||c == 8287||c == 12288) {
return true;
}
if ((c >= 9 && c <= 10)||(c >= 8192 && c <= 8202)) {
return true;
}
return false;
}
// Generated by tools/generate_unicode_category_table.py
bool IsPunct(char32_t c) {
if (c == 161||c == 167||c == 171||c == 187||c == 191||c == 894||c == 903||c == 12336||c == 12349) {
return true;
}
if ((c >= 33 && c <= 47)||(c >= 58 && c <= 64)||(c >= 91 && c <= 96)||(c >= 123 && c <= 126)
||(c >= 182 && c <= 183)||(c >= 8208 && c <= 8231)||(c >= 8240 && c <= 8259)
|| (c >= 8261 && c <= 8273)||(c >= 8275 && c <= 8286)||(c >= 12289 && c <= 12291)
||(c >= 12296 && c <= 12305)||(c >= 12308 && c <= 12319)) {
return true;
}
return false;
}
// Generated by tools/generate_unicode_category_table.py
bool IsControl(char32_t c) {
if (c == 173||c == 907||c == 909||c == 930||c == 11930||c == 173790||c == 195102
) {
return true;
}
if ((c >= 0 && c <= 8)||(c >= 11 && c <= 12)||(c >= 14 && c <= 31)||(c >= 128 && c <= 159)
||(c >= 888 && c <= 889)||(c >= 896 && c <= 899)||(c >= 8203 && c <= 8207)
||(c >= 8234 && c <= 8238)||(c >= 8288 && c <= 8302)||(c >= 12020 && c <= 12030)
||(c >= 40957 && c <= 40958)||(c >= 64110 && c <= 64111)||(c >= 64218 && c <= 64254)
||(c >= 177973 && c <= 177982)||(c >= 178206 && c <= 178207)||(c >= 183970 && c <= 183982)) {
return true;
}
return false;
}
bool IsAccent(char32_t c)
{
// only support part of accent

Просмотреть файл

@ -59,6 +59,12 @@ bool IsCJK(char32_t c);
bool IsAccent(char32_t c);
bool IsSpace(char32_t c);
bool IsPunct(char32_t c);
bool IsControl(char32_t c);
char32_t StripAccent(char32_t c);
uint64_t Hash64(const char* data, size_t n, uint64_t seed);

Просмотреть файл

@ -54,22 +54,20 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
}
// 0x2019 unicode is not punctuation in some Linux platform,
// to be consistent, take it as punctatuation always.
if (tokenize_punctuation_ && (::iswpunct(c) || c == wint_t(0x2019))) {
// to be consistent, take it as punctuation.
if (tokenize_punctuation_ && IsPunct(c)) {
push_current_token_and_clear();
push_single_char_and_clear(c);
continue;
}
// split by space
if (::iswspace(c)) {
if (IsSpace(c)) {
push_current_token_and_clear();
continue;
}
// iscntrl will judge \t\f\n\r as control char
// but it has been filter by isspace(c)
if (remove_control_chars_ && ::iswcntrl(c)) {
if (remove_control_chars_ && IsControl(c)) {
continue;
}

Двоичные данные
test/data/test_bert_tokenizer1.onnx Normal file

Двоичный файл не отображается.

Двоичные данные
test/data/test_segment_extraction.onnx Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include "gtest/gtest.h"
#include "ocos.h"
#include "test_kernel.hpp"
TEST(utils, test_bert_tokenizer) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
std::vector<TestValue> inputs(1);
inputs[0].name = "text";
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
inputs[0].dims = {1};
inputs[0].values_string = {"We look forward to welcoming you to our stores. Whether you shop in a store or shop online, our Specialists can help you buy the products you love."};
std::vector<TestValue> outputs(3);
outputs[0].name = "input_ids";
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[0].dims = {34};
outputs[0].values_int64 = {101, 1284, 1440, 1977, 1106, 20028, 1128, 1106, 1412, 4822, 119, 13197, 1128, 4130, 1107, 170, 2984, 1137, 4130, 3294, 117, 1412, 25607, 1116, 1169, 1494, 1128, 4417, 1103, 2982, 1128, 1567, 119, 102};
outputs[1].name = "token_type_ids";
outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[1].dims = {34};
outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
outputs[2].name = "attention_mask";
outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[2].dims = {34};
outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
std::filesystem::path model_path = __FILE__;
model_path = model_path.parent_path();
model_path /= "..";
model_path /= "data";
model_path /= "test_bert_tokenizer1.onnx";
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
inputs[0].name = "text";
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
inputs[0].dims = {1};
inputs[0].values_string = {"本想好好的伤感 想放任 但是没泪痕"};
outputs[0].name = "input_ids";
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[0].dims = {17};
outputs[0].values_int64 = {101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102};
outputs[1].name = "token_type_ids";
outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[1].dims = {17};
outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
outputs[2].name = "attention_mask";
outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[2].dims = {17};
outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
inputs[0].name = "text";
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
inputs[0].dims = {1};
inputs[0].values_string = {"ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~"};
outputs[0].name = "input_ids";
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[0].dims = {71};
outputs[0].values_int64 = {101, 13807, 11189, 8101, 27073, 27073, 12738, 11607, 2346, 2346, 2346, 2346, 2346, 2591, 2591, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 106, 108, 109, 110, 111, 113, 2545, 137, 17599, 7301, 4964, 119, 3254, 114, 115, 116, 117, 118, 119, 120, 131, 132, 133, 134, 135, 136, 137, 164, 165, 166, 167, 168, 169, 196, 197, 198, 199, 102};
outputs[1].name = "token_type_ids";
outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[1].dims = {71};
outputs[1].values_int64 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
outputs[2].name = "attention_mask";
outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[2].dims = {71};
outputs[2].values_int64 = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}

Просмотреть файл

@ -4,24 +4,27 @@ import numpy as np
import transformers
from onnxruntime_extensions import PyOrtFunction, BertTokenizer
bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
def _get_test_data_file(*sub_dirs):
test_dir = Path(__file__).parent
return str(test_dir.joinpath(*sub_dirs))
bert_cased_tokenizer = transformers.BertTokenizer(_get_test_data_file('data', 'bert_basic_cased_vocab.txt'), False,
strip_accents=True)
def _run_basic_case(input, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
result = t2stc([input])
expect_result = bert_cased_tokenizer.encode_plus(input)
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
def _run_combined_case(input, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
result = t2stc(input)
expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
@ -34,17 +37,19 @@ class TestBertTokenizer(unittest.TestCase):
def test_text_to_case1(self):
_run_basic_case(input="Input 'text' must not be empty.",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(
input="ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="网易云音乐", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="本想好好的伤感 想放任 但是没泪痕", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="网 易 云 音 乐",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat is playing toys",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat isnot playing toyssss",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat isnot playing toyssss",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"], vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"],
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
if __name__ == "__main__":

Просмотреть файл

@ -0,0 +1,144 @@
import unicodedata
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `char` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `char` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def find_expect_char_in_range(judge_fun, start, end):
result = []
for c in range(start, end):
if judge_fun(chr(c)):
result.append(c)
return result
def find_ranges(nums):
nums = sorted(set(nums))
gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
return list(zip(edges, edges))
def find_expect_category(category_func):
expect_category_set = []
# ASCII
expect_category_set += find_expect_char_in_range(category_func, 0, 0x7F)
# C1 Controls and Latin-1 Supplement
expect_category_set += find_expect_char_in_range(category_func, 0x80, 0xFF)
# Latin Extended-A
expect_category_set += find_expect_char_in_range(category_func, 0x100, 0x17F)
# Latin Extended-B
expect_category_set += find_expect_char_in_range(category_func, 0x180, 0x24F)
# IPA Extensions
expect_category_set += find_expect_char_in_range(category_func, 0x250, 0x2AF)
# Spacing Modifier Letters
expect_category_set += find_expect_char_in_range(category_func, 0x2B0, 0x2FF)
# Combining Diacritical Marks
expect_category_set += find_expect_char_in_range(category_func, 0x300, 0x36F)
# Greek/Coptic
expect_category_set += find_expect_char_in_range(category_func, 0x370, 0x3FF)
# Cyrillic and Cyrillic Supplement
expect_category_set += find_expect_char_in_range(category_func, 0x400, 0x52F)
# General Punctuation
expect_category_set += find_expect_char_in_range(category_func, 0x2000, 0x206F)
# CJK Radicals Supplement
expect_category_set += find_expect_char_in_range(category_func, 0x2E80, 0x2EFF)
# CJK Symbols and Punctuation
expect_category_set += find_expect_char_in_range(category_func, 0x3000, 0x303F)
# CJK
expect_category_set += find_expect_char_in_range(category_func, 0x4E00, 0x9FFF)
expect_category_set += find_expect_char_in_range(category_func, 0x3400, 0x4DBF)
expect_category_set += find_expect_char_in_range(category_func, 0x20000, 0x2A6DF)
expect_category_set += find_expect_char_in_range(category_func, 0x2A700, 0x2B73F)
expect_category_set += find_expect_char_in_range(category_func, 0x2B740, 0x2CEAF)
expect_category_set += find_expect_char_in_range(category_func, 0xF900, 0xFAFF)
expect_category_set += find_expect_char_in_range(category_func, 0x2F800, 0x2FA1F)
return find_ranges(expect_category_set)
def print_range(ranges):
single_set = []
pair_set = []
for r in ranges:
start, end = r
if start == end:
single_set.append(start)
else:
pair_set.append(r)
output = "if ("
for i in range(len(single_set)):
if i != 0:
output += "||"
output += f"c == {single_set[i]}"
output += ") {\n return true;\n}\n\n"
output += "if ("
for i in range(len(pair_set)):
if i != 0:
output += "||"
start, end = pair_set[i]
output += f"(c >= {start} && c <= {end})"
output += ") {\n return true;\n}\n\nreturn false;\n"
print(output)
print("\nis_whitespace:")
print_range(find_expect_category(_is_whitespace))
print("\nis_punctuation:")
print_range(find_expect_category(_is_punctuation))
print("\nis_control:")
print_range(find_expect_category(_is_control))