// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <clocale>

#include "gtest/gtest.h"
#include "string_utils.h"
#include "wordpiece_tokenizer.hpp"
#include "bert_tokenizer.hpp"
#include "tokenizer/basic_tokenizer.hpp"
class LocaleBaseTest : public testing::Test {
 public:
  // Remember that SetUp() is run immediately before a test starts.
  void SetUp() override {
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__)) && !defined(__GNUC__)
    default_locale_ = std::locale().name();
    std::setlocale(LC_CTYPE, "C");
#else
    default_locale_ = std::locale("").name();
    std::setlocale(LC_CTYPE, "en_US.UTF-8");
#endif
  }

  // TearDown() is invoked immediately after a test finishes.
  void TearDown() override {
    if (!default_locale_.empty()) {
      std::setlocale(LC_CTYPE, default_locale_.c_str());
    }
  }

 private:
  std::string default_locale_;
};

#if defined(ENABLE_WORDPIECE_TOKENIZER) && defined(ENABLE_BERT_TOKENIZER)
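
// KernelWordpieceTokenizer_Split breaks text on whitespace; leading and
// trailing spaces produce no extra tokens.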
TEST(tokenizer, bert_word_split) {
  ustring ind("##");
  ustring text("A AAA B BB");
  std::vector<std::u32string> words;
  KernelWordpieceTokenizer_Split(ind, text, words);
  std::vector<std::u32string> expected{ustring("A"), ustring("AAA"), ustring("B"), ustring("BB")};
  EXPECT_EQ(expected, words);

  text = ustring(" A AAA B BB ");
  KernelWordpieceTokenizer_Split(ind, text, words);
  EXPECT_EQ(words, expected);
}
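
// Builds a small BERT-style vocabulary; each token's id is its insertion
// order, so "[UNK]" is 0 and "lowest" is 14.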
std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
  std::vector<ustring> vocab_tokens = {
      ustring("[UNK]"),
      ustring("[CLS]"),
      ustring("[SEP]"),
      ustring("[PAD]"),
      ustring("[MASK]"),
      ustring("want"),
      ustring("##want"),
      ustring("##ed"),
      ustring("wa"),
      ustring("un"),
      ustring("runn"),
      ustring("##ing"),
      ustring(","),
      ustring("low"),
      ustring("lowest"),
  };
  std::unordered_map<std::u32string, int32_t> vocab;
  for (const auto& token : vocab_tokens) {
    vocab[token] = static_cast<int32_t>(vocab.size());
  }
  return vocab;
}

std::vector<ustring> ustring_vector_convertor(const std::vector<std::string>& input) {
  std::vector<ustring> result;
  result.reserve(input.size());
  for (const auto& str : input) {
    result.emplace_back(str);
  }
  return result;
}
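
// Smoke test: tokenizes accented, mixed-case input ("UNwant\u00E9d,running")
// without asserting on the resulting ids.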
TEST(tokenizer, wordpiece_basic_tokenizer) {
  auto vocab = get_vocabulary_basic();
  std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
  std::vector<ustring> tokens;
  std::vector<int32_t> indices;
  std::vector<int64_t> rows;
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
  // EXPECT_EQ(indices, std::vector<int32_t>({9, 6, 7, 12, 10, 11}));
  // EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
}

std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
  std::vector<ustring> vocab_tokens = {
      ustring("[UNK]"),   // 0
      ustring("[CLS]"),   // 1
      ustring("[SEP]"),   // 2
      ustring("want"),    // 3
      ustring("##want"),  // 4
      ustring("##ed"),    // 5
      ustring("wa"),      // 6
      ustring("un"),      // 7
      ustring("runn"),    // 8
      ustring("##ing"),   // 9
  };
  std::unordered_map<std::u32string, int32_t> vocab;
  for (const auto& token : vocab_tokens) {
    vocab[token] = static_cast<int32_t>(vocab.size());
  }
  return vocab;
}
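
// Greedy longest-match-first wordpiece: "unwanted" -> "un" + "##want" + "##ed".
// A suffix with no vocabulary match (the "X" in "unwantedX") is emitted as the
// unknown token with index -1, and `rows` holds per-input token offsets.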
TEST(tokenizer, wordpiece_wordpiece_tokenizer) {
  auto vocab = get_vocabulary_wordpiece();
  std::vector<int32_t> indices;
  std::vector<int64_t> rows;
  std::vector<ustring> tokens;

  std::vector<ustring> text = {ustring("unwanted running")};  // "un", "##want", "##ed", "runn", "##ing"
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
  EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
                                          ustring("runn"), ustring("##ing")}));
  EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9}));
  EXPECT_EQ(rows, std::vector<int64_t>({0, 5}));

  text = std::vector<ustring>({ustring("unwantedX running")});  // "un", "##want", "##ed", "[UNK]", "runn", "##ing"
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
  EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
                                          ustring("[UNK]"), ustring("runn"), ustring("##ing")}));
  EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, -1, 8, 9}));
  EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));

  text = std::vector<ustring>({ustring("")});  // empty input yields no tokens
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
  EXPECT_EQ(tokens, std::vector<ustring>());
  EXPECT_EQ(indices, std::vector<int32_t>());
  EXPECT_EQ(rows, std::vector<int64_t>({0, 0}));
}
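
// When word-level row boundaries are passed in via existing_indices, the
// returned `rows` are remapped to token-level offsets: words [0, 2) expand to
// 5 wordpieces and word [2, 3) to 2, giving {0, 5, 7}.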
TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
  auto vocab = get_vocabulary_wordpiece();
  std::vector<int32_t> indices;
  std::vector<int64_t> rows;
  std::vector<ustring> tokens;
  std::vector<int64_t> existing_indices({0, 2, 3});
  std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
                                     existing_indices.data(), existing_indices.size());
  EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
                                          ustring("runn"), ustring("##ing"),
                                          ustring("runn"), ustring("##ing")}));
  EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
  EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
}
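
// With all options enabled (lower-casing, CJK splitting, accent stripping,
// punctuation splitting, control-char removal), the accented Latin run
// collapses into one lower-cased token, every CJK character becomes its own
// token, and each punctuation mark is split out.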
TEST_F(LocaleBaseTest, basic_tokenizer_chinese) {
  ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
  std::vector<ustring> expect_result = ustring_vector_convertor(
      {"aaaaaaceeeeiiinooooouu",
       "䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡",
       "!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":",
       ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
  BasicTokenizer tokenizer(true, true, true, true, true);
  auto result = tokenizer.Tokenize(test_case);
  EXPECT_EQ(result, expect_result);
}
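
// Cyrillic words pass through as whole tokens; "$100,000" and
// "price-tag@big>small" are split at punctuation.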
TEST_F(LocaleBaseTest, basic_tokenizer_russia) {
  ustring test_case = ustring("A $100,000 price-tag@big>small на русском языке");
  std::vector<ustring> expect_result = ustring_vector_convertor(
      {"a", "$", "100", ",", "000", "price", "-", "tag", "@", "big", ">", "small", "на", "русском", "языке"});
  BasicTokenizer tokenizer(true, true, true, true, true);
  auto result = tokenizer.Tokenize(test_case);
  EXPECT_EQ(result, expect_result);
}

TEST_F(LocaleBaseTest, basic_tokenizer) {
  // The apostrophe is assumed to be the typographic U+2019; the punctuation
  // splitter emits it as its own token between "you" and "ll".
  ustring test_case = ustring("I mean, you’ll need something to talk about next Sunday, right?");
  std::vector<ustring> expect_result = ustring_vector_convertor(
      {"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
  BasicTokenizer tokenizer(false, true, true, true, true);
  auto result = tokenizer.Tokenize(test_case);
  EXPECT_EQ(result, expect_result);
}
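
// With a single input, "longest_first" truncates the sequence to max_len;
// a limit of -1 means no truncation.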
TEST(tokenizer, truncation_one_input) {
  TruncateStrategy truncate("longest_first");
  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});

  auto test_input = init_vector1;
  truncate.Truncate(test_input, -1);
  EXPECT_EQ(test_input, init_vector1);

  test_input = init_vector1;
  truncate.Truncate(test_input, 5);
  EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));

  test_input = init_vector2;
  truncate.Truncate(test_input, 6);
  EXPECT_EQ(test_input, init_vector2);
}
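
// With two inputs, "longest_first" repeatedly trims the currently longer
// sequence until the combined length fits the limit, e.g. lengths (8, 5)
// under a limit of 9 become (5, 4).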
TEST(tokenizer, truncation_longest_first) {
  TruncateStrategy truncate("longest_first");
  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});

  // -1 means no limit.
  auto test_input1 = init_vector1;
  auto test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, -1);
  EXPECT_EQ(test_input1, init_vector1);
  EXPECT_EQ(test_input2, init_vector2);

  // Limits at or above the combined length (13) leave both inputs intact.
  test_input1 = init_vector1;
  test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, 15);
  EXPECT_EQ(test_input1, init_vector1);
  EXPECT_EQ(test_input2, init_vector2);

  test_input1 = init_vector1;
  test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, 14);
  EXPECT_EQ(test_input1, init_vector1);
  EXPECT_EQ(test_input2, init_vector2);

  // Tighter limits trim the longer input first: 8 -> (4, 4), 9 -> (5, 4).
  test_input1 = init_vector1;
  test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, 8);
  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));

  test_input1 = init_vector1;
  test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, 9);
  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));

  test_input1 = init_vector1;
  test_input2 = init_vector2;
  truncate.Truncate(test_input1, test_input2, 12);
  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));

  // The behavior is symmetric when the argument order is swapped.
  test_input1 = init_vector2;
  test_input2 = init_vector1;
  truncate.Truncate(test_input1, test_input2, 12);
  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
}
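
// Same sentence as basic_tokenizer above, driven through the eager kernel
// API: NamedArgumentDict pairs each option name with the corresponding value
// in the tuple.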
TEST(tokenizer, basic_tok_eager) {
  // As above, the apostrophe is assumed to be U+2019, UTF-8 encoded in this
  // narrow string literal.
  std::string test_case = "I mean, you’ll need something to talk about next Sunday, right?";
  std::vector<std::string> expect_result = {"I", "mean", ",", "you", "’", "ll", "need", "something", "to",
                                            "talk", "about", "next", "Sunday", ",", "right", "?"};
  ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"},
                               std::make_tuple(false, true, true, true, true));
  KernelBasicTokenizer tokenizer(dict);
  ortc::Tensor<std::string> output;
  tokenizer.Compute(test_case, output);
  EXPECT_EQ(output.Data(), expect_result);
}

#endif