268 строки
10 KiB
C++
268 строки
10 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||
// Licensed under the MIT License.
|
||
|
||
#include "gtest/gtest.h"
|
||
#include "string_utils.h"
|
||
#include "wordpiece_tokenizer.hpp"
|
||
#include "bert_tokenizer.hpp"
|
||
|
||
#include <clocale>
|
||
#include "tokenizer/basic_tokenizer.hpp"
|
||
|
||
class LocaleBaseTest : public testing::Test {
|
||
public:
|
||
// Remember that SetUp() is run immediately before a test starts.
|
||
void SetUp() override {
|
||
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
|
||
default_locale_ = std::locale().name();
|
||
std::setlocale(LC_CTYPE, "C");
|
||
#else
|
||
default_locale_ = std::locale("").name();
|
||
std::setlocale(LC_CTYPE, "en_US.UTF-8");
|
||
#endif
|
||
}
|
||
// TearDown() is invoked immediately after a test finishes.
|
||
void TearDown() override {
|
||
if (!default_locale_.empty()) {
|
||
std::setlocale(LC_CTYPE, default_locale_.c_str());
|
||
}
|
||
}
|
||
|
||
private:
|
||
std::string default_locale_;
|
||
};
|
||
|
||
#if defined(ENABLE_WORDPIECE_TOKENIZER) && defined(ENABLE_BERT_TOKENIZER)
|
||
|
||
TEST(tokenizer, bert_word_split) {
|
||
ustring ind("##");
|
||
ustring text("A AAA B BB");
|
||
std::vector<std::u32string> words;
|
||
KernelWordpieceTokenizer_Split(ind, text, words);
|
||
std::vector<std::u32string> expected{ustring("A"), ustring("AAA"), ustring("B"), ustring("BB")};
|
||
EXPECT_EQ(expected, words);
|
||
|
||
text = ustring(" A AAA B BB ");
|
||
KernelWordpieceTokenizer_Split(ind, text, words);
|
||
EXPECT_EQ(words, expected);
|
||
}
|
||
|
||
std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
|
||
std::vector<ustring> vocab_tokens = {
|
||
ustring("[UNK]"),
|
||
ustring("[CLS]"),
|
||
ustring("[SEP]"),
|
||
ustring("[PAD]"),
|
||
ustring("[MASK]"),
|
||
ustring("want"),
|
||
ustring("##want"),
|
||
ustring("##ed"),
|
||
ustring("wa"),
|
||
ustring("un"),
|
||
ustring("runn"),
|
||
ustring("##ing"),
|
||
ustring(","),
|
||
ustring("low"),
|
||
ustring("lowest"),
|
||
};
|
||
std::unordered_map<std::u32string, int32_t> vocab;
|
||
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
|
||
vocab[*it] = static_cast<int32_t>(vocab.size());
|
||
}
|
||
return vocab;
|
||
}
|
||
|
||
std::vector<ustring> ustring_vector_convertor(std::vector<std::string> input) {
|
||
std::vector<ustring> result;
|
||
for (const auto& str : input) {
|
||
result.emplace_back(str);
|
||
}
|
||
return result;
|
||
}
|
||
|
||
TEST(tokenizer, wordpiece_basic_tokenizer) {
|
||
auto vocab = get_vocabulary_basic();
|
||
std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
|
||
std::vector<ustring> tokens;
|
||
std::vector<int32_t> indices;
|
||
std::vector<int64_t> rows;
|
||
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
|
||
// EXPECT_EQ(indices, std::vector<int32_t>({9, 6, 7, 12, 10, 11}));
|
||
// EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
|
||
}
|
||
|
||
std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
|
||
std::vector<ustring> vocab_tokens = {
|
||
ustring("[UNK]"), // 0
|
||
ustring("[CLS]"), // 1
|
||
ustring("[SEP]"), // 2
|
||
ustring("want"), // 3
|
||
ustring("##want"), // 4
|
||
ustring("##ed"), // 5
|
||
ustring("wa"), // 6
|
||
ustring("un"), // 7
|
||
ustring("runn"), // 8
|
||
ustring("##ing"), // 9
|
||
};
|
||
std::unordered_map<std::u32string, int32_t> vocab;
|
||
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
|
||
vocab[*it] = static_cast<int32_t>(vocab.size());
|
||
}
|
||
return vocab;
|
||
}
|
||
|
||
TEST(tokenizer, wordpiece_wordpiece_tokenizer) {
|
||
auto vocab = get_vocabulary_wordpiece();
|
||
std::vector<int32_t> indices;
|
||
std::vector<int64_t> rows;
|
||
std::vector<ustring> tokens;
|
||
|
||
std::vector<ustring> text = {ustring("unwanted running")}; // "un", "##want", "##ed", "runn", "##ing"
|
||
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
|
||
EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
|
||
ustring("runn"), ustring("##ing")}));
|
||
EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9}));
|
||
EXPECT_EQ(rows, std::vector<int64_t>({0, 5}));
|
||
|
||
text = std::vector<ustring>({ustring("unwantedX running")}); // "[UNK]", "runn", "##ing"
|
||
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
|
||
EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
|
||
ustring("[UNK]"), ustring("runn"), ustring("##ing")}));
|
||
EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, -1, 8, 9}));
|
||
EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
|
||
|
||
text = std::vector<ustring>({ustring("")}); //
|
||
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
|
||
EXPECT_EQ(tokens, std::vector<ustring>());
|
||
EXPECT_EQ(indices, std::vector<int32_t>());
|
||
EXPECT_EQ(rows, std::vector<int64_t>({0, 0}));
|
||
}
|
||
|
||
TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
|
||
auto vocab = get_vocabulary_wordpiece();
|
||
std::vector<int32_t> indices;
|
||
std::vector<int64_t> rows;
|
||
std::vector<ustring> tokens;
|
||
|
||
std::vector<int64_t> existing_indices({0, 2, 3});
|
||
std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
|
||
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
|
||
existing_indices.data(), existing_indices.size());
|
||
EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
|
||
ustring("runn"), ustring("##ing"),
|
||
ustring("runn"), ustring("##ing")}));
|
||
EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
|
||
EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
|
||
}
|
||
|
||
TEST_F(LocaleBaseTest, basic_tokenizer_chinese) {
|
||
ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
|
||
std::vector<ustring> expect_result = ustring_vector_convertor({"aaaaaaceeeeiiinooooouu",
|
||
"䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡",
|
||
"!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":",
|
||
";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
|
||
BasicTokenizer tokenizer(true, true, true, true, true);
|
||
auto result = tokenizer.Tokenize(test_case);
|
||
EXPECT_EQ(result, expect_result);
|
||
}
|
||
|
||
TEST_F(LocaleBaseTest, basic_tokenizer_russia) {
|
||
ustring test_case = ustring("A $100,000 price-tag@big>small на русском языке");
|
||
std::vector<ustring> expect_result = ustring_vector_convertor({"a", "$", "100", ",", "000", "price", "-", "tag", "@", "big", ">", "small", "на", "русском", "языке"});
|
||
BasicTokenizer tokenizer(true, true, true, true, true);
|
||
auto result = tokenizer.Tokenize(test_case);
|
||
EXPECT_EQ(result, expect_result);
|
||
}
|
||
|
||
TEST_F(LocaleBaseTest, basic_tokenizer) {
|
||
ustring test_case = ustring("I mean, you’ll need something to talk about next Sunday, right?");
|
||
std::vector<ustring> expect_result = ustring_vector_convertor({"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
|
||
BasicTokenizer tokenizer(false, true, true, true, true);
|
||
auto result = tokenizer.Tokenize(test_case);
|
||
EXPECT_EQ(result, expect_result);
|
||
}
|
||
|
||
TEST(tokenizer, truncation_one_input) {
|
||
TruncateStrategy truncate("longest_first");
|
||
|
||
std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
|
||
std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
|
||
|
||
auto test_input = init_vector1;
|
||
truncate.Truncate(test_input, -1);
|
||
EXPECT_EQ(test_input, init_vector1);
|
||
|
||
test_input = init_vector1;
|
||
truncate.Truncate(test_input, 5);
|
||
EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));
|
||
|
||
test_input = init_vector2;
|
||
truncate.Truncate(test_input, 6);
|
||
EXPECT_EQ(test_input, init_vector2);
|
||
}
|
||
|
||
TEST(tokenizer, truncation_longest_first) {
|
||
TruncateStrategy truncate("longest_first");
|
||
|
||
std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
|
||
std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
|
||
|
||
auto test_input1 = init_vector1;
|
||
auto test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, -1);
|
||
EXPECT_EQ(test_input1, init_vector1);
|
||
EXPECT_EQ(test_input2, init_vector2);
|
||
|
||
test_input1 = init_vector1;
|
||
test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, 15);
|
||
EXPECT_EQ(test_input1, init_vector1);
|
||
EXPECT_EQ(test_input2, init_vector2);
|
||
|
||
test_input1 = init_vector1;
|
||
test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, 14);
|
||
EXPECT_EQ(test_input1, init_vector1);
|
||
EXPECT_EQ(test_input2, init_vector2);
|
||
|
||
test_input1 = init_vector1;
|
||
test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, 8);
|
||
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
|
||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
|
||
|
||
test_input1 = init_vector1;
|
||
test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, 9);
|
||
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
|
||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
|
||
|
||
test_input1 = init_vector1;
|
||
test_input2 = init_vector2;
|
||
truncate.Truncate(test_input1, test_input2, 12);
|
||
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
|
||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));
|
||
|
||
test_input1 = init_vector2;
|
||
test_input2 = init_vector1;
|
||
truncate.Truncate(test_input1, test_input2, 12);
|
||
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
|
||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
|
||
}
|
||
|
||
TEST(tokenizer, basic_tok_eager) {
|
||
std::string test_case = "I mean, you’ll need something to talk about next Sunday, right?";
|
||
std::vector<std::string> expect_result = {"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"};
|
||
|
||
ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"},
|
||
std::make_tuple(false, true, true, true, true));
|
||
|
||
KernelBasicTokenizer tokenizer(dict);
|
||
|
||
ortc::Tensor<std::string> output;
|
||
tokenizer.Compute(test_case, output);
|
||
EXPECT_EQ(output.Data(), expect_result);
|
||
}
|
||
|
||
#endif
|