Initial BertTokenizer offset mapping implementation (#477)

* Initial BertTokenizer offset mapping implementation

* minor change

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Sayan Shaw 2023-07-03 15:17:23 -07:00 committed by GitHub
Parent afb3e83df2
Commit d876f7ff82
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 30 deletions

View File

@@ -237,7 +237,8 @@ class BertTokenizer(CustomOp):
return [
cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])
cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None]),
cls.io_def('offset_mapping', onnx_proto.TensorProto.INT64, [None, 2])
]
@classmethod
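
The schema change above adds a fourth op output, offset_mapping: an int64 tensor of shape [num_tokens, 2] holding one (start, end) character span per token. A minimal consumption sketch, assuming the same test vocabulary used in test_bert_tokenizer.py below (vocab_path and the sample sentence are illustrative):

    from onnxruntime_extensions import PyOrtFunction, BertTokenizer, util

    vocab_path = util.get_test_data_file("data", "bert_basic_cased_vocab.txt")
    tokenize = PyOrtFunction.from_customop(
        BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1
    )
    result = tokenize(["cat is playing toys"])
    # result[0..2]: input_ids, token_type_ids, attention_mask (unchanged outputs)
    # result[3]: offset_mapping, one (start, end) row per token, (0, 0) for [CLS]/[SEP]
    print(result[3])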

View File

@@ -1,6 +1,9 @@
#include "bert_tokenizer.hpp"
#include <utility>
#include <iostream>
#include <optional>
#include <list>
BertTokenizerVocab::BertTokenizerVocab(std::string_view vocab) : raw_vocab_(vocab) {
auto tokens = SplitString(raw_vocab_, "\r\n", true);
@@ -50,7 +53,7 @@ WordpieceTokenizer::WordpieceTokenizer(
unk_token_id_ = vocab_->FindTokenId(unk_token_);
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map) {
std::vector<ustring> result;
ustring token;
for (auto c : text) {
@@ -70,12 +73,43 @@ std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
return result;
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map) {
std::vector<ustring> result;
for (const auto& token : tokens) {
GreedySearch(token, result);
}
size_t offset = 0;
OffsetMappingType offset_mapping;
// Add offset mapping for BOS token
offset_mapping.push_back(std::make_pair(0, 0));
for (auto i : result) {
// Handle special cases for offset mapping
size_t idx = 0;
if (idx < std::string(i).size() && std::string(i).at(idx) == '#') {
while (idx < std::string(i).size() && std::string(i).at(idx) == '#') {
idx++;
}
offset--;
offset_mapping.emplace_back(std::make_pair(offset, offset + std::string(i).size() - idx));
offset += (std::string(i).size() - idx) + 1;
} else if (std::string(i).compare("[UNK]") == 0) {
offset_mapping.emplace_back(std::make_pair(offset, offset + 1));
offset += 2;
} else {
offset_mapping.emplace_back(std::make_pair(offset, offset + std::string(i).size()));
offset += std::string(i).size() + 1;
}
}
// Add offset mapping for EOS token
offset_mapping.emplace_back(std::make_pair(0, 0));
// Add offset mappings for input in this instance to list of offset mappings for all inputs
offset_map.emplace_back(offset_mapping);
return result;
}
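
The new block computes character spans for the wordpiece tokens of a whitespace-tokenized input: a token beginning with the suffix indicator ('#') continues the previous word, so the running offset backs up by one to cancel the space advance; "[UNK]" is assumed to cover a single character; any other token spans its own length, after which the offset skips one position for the separating space. (0, 0) entries are emitted for the BOS and EOS special tokens. The same arithmetic as a standalone Python sketch (the function name and the "##" suffix assumption are illustrative, not part of the commit):

    def wordpiece_offsets(tokens):
        offsets = [(0, 0)]                   # BOS ([CLS]) placeholder
        pos = 0
        for tok in tokens:
            if tok.startswith("#"):
                piece = tok.lstrip("#")      # drop the suffix indicator
                pos -= 1                     # continue previous word: undo the space advance
                offsets.append((pos, pos + len(piece)))
                pos += len(piece) + 1
            elif tok == "[UNK]":
                offsets.append((pos, pos + 1))   # unknown token assumed to span one character
                pos += 2
            else:
                offsets.append((pos, pos + len(tok)))
                pos += len(tok) + 1
        offsets.append((0, 0))               # EOS ([SEP]) placeholder
        return offsets

    # wordpiece_offsets(["cat", "is", "playing", "toys"])
    # -> [(0, 0), (0, 3), (4, 6), (7, 14), (15, 19), (0, 0)]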
@@ -206,11 +240,11 @@ BertTokenizer::BertTokenizer(
mask_token_id_ = vocab_->FindTokenId(mask_token);
}
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map) {
if (do_basic_tokenize_) {
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text), offset_map);
}
return wordpiece_tokenizer_->Tokenize(text);
return wordpiece_tokenizer_->Tokenize(text, offset_map);
}
std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
@@ -295,7 +329,8 @@ KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo&
void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2) {
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
// Setup inputs
auto& input_data = input.Data();
@@ -304,16 +339,17 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
}
std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids;
std::list<OffsetMappingType> offset_map;
if (input_data.size() == 1) {
std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]), offset_map);
std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
tokenizer_->Truncate(encoded);
input_ids = tokenizer_->AddSpecialToken(encoded);
token_type_ids = tokenizer_->GenerateTypeId(encoded);
} else {
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]), offset_map);
std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]), offset_map);
std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
@@ -330,6 +366,21 @@ void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::copy(token_type_ids.begin(), token_type_ids.end(), p_out1);
auto* p_out2 = output2.Allocate(output_dim);
std::copy(attention_mask.begin(), attention_mask.end(), p_out2);
std::vector<int64_t> offset_dim{static_cast<int64_t>(input_ids.size()), 2}; // tuple of offsets for each input id
if (offset_mapping.has_value()) {
auto* offset = (*offset_mapping)->Allocate(offset_dim);
int idx2 = 0;
for (auto& res : offset_map) {
for (auto& mapping : res) {
offset[idx2] = mapping.first;
idx2++;
offset[idx2] = mapping.second;
idx2++;
}
}
}
}
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
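
When the optional offset_mapping output is present (offset_mapping.has_value()), the kernel flattens the per-sentence offset lists row by row: each (start, end) pair becomes one row of the [len(input_ids), 2] int64 buffer, in token order. A NumPy sketch of the same flattening (the offset_map value is illustrative):

    import numpy as np

    # one list of (start, end) pairs per input sentence, as produced by Tokenize
    offset_map = [[(0, 0), (0, 3), (4, 6), (7, 14), (15, 19), (0, 0)]]

    flat = []
    for sentence in offset_map:
        for start, end in sentence:
            flat.extend((start, end))
    offset_mapping = np.asarray(flat, dtype=np.int64).reshape(-1, 2)  # shape [num_tokens, 2]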
@@ -338,7 +389,8 @@ KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelI
void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2) {
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
// Setup inputs
auto& input_data = input.Data();
@@ -346,8 +398,10 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
}
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
std::list<OffsetMappingType> offset_map;
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]), offset_map);
std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]), offset_map);
std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
@@ -362,4 +416,19 @@ void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::copy(attention_mask.begin(), attention_mask.end(), p_out1);
auto* p_out2 = output2.Allocate(outer_dims);
std::copy(token_type_ids.begin(), token_type_ids.end(), p_out2);
std::vector<int64_t> offset_dim{static_cast<int64_t>(input_ids.size()), 2}; // tuple of offsets for each input id
if (offset_mapping.has_value()) {
auto* offset = (*offset_mapping)->Allocate(offset_dim);
int idx2 = 0;
for (auto& res : offset_map) {
for (auto& mapping : res) {
offset[idx2] = mapping.first;
idx2++;
offset[idx2] = mapping.second;
idx2++;
}
}
}
}

View File

@@ -10,6 +10,7 @@
#include "basic_tokenizer.hpp"
#include <unordered_map>
#include <list>
class BertTokenizerVocab final {
public:
@@ -44,8 +45,9 @@ class WordpieceTokenizer final {
WordpieceTokenizer(
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word = 100);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens, std::list<OffsetMappingType>& offset_map);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
private:
@@ -64,7 +66,8 @@ class BertTokenizer final {
ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator, int32_t max_len, const std::string& truncation_strategy);
std::vector<ustring> Tokenize(const ustring& text);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
std::vector<ustring> Tokenize(const ustring& text, std::list<OffsetMappingType>& offset_map);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
void Truncate(std::vector<int64_t>& ids);
@@ -94,7 +97,9 @@ struct KernelBertTokenizer : BaseKernel {
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2);
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
protected:
std::unique_ptr<BertTokenizer> tokenizer_;
@@ -102,8 +107,10 @@ struct KernelBertTokenizer : BaseKernel {
struct KernelHfBertTokenizer : KernelBertTokenizer {
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2);
ortc::Tensor<int64_t>& output2,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
};

View File

@@ -3,6 +3,7 @@ import unittest
import numpy as np
import transformers
from onnxruntime_extensions import PyOrtFunction, BertTokenizer, util
from transformers import BertTokenizerFast
bert_cased_tokenizer = transformers.BertTokenizer(
@@ -33,9 +34,31 @@ def _run_combined_case(input, vocab_path):
np.testing.assert_array_equal(result[1], expect_result["token_type_ids"])
np.testing.assert_array_equal(result[2], expect_result["attention_mask"])
def _run_basic_with_offset_check(input, vocab_path):
t2stc = PyOrtFunction.from_customop(
BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1
)
result = t2stc([input])
expect_result = bert_cased_tokenizer.encode_plus(input)
np.testing.assert_array_equal(result[0], expect_result["input_ids"])
np.testing.assert_array_equal(result[1], expect_result["token_type_ids"])
np.testing.assert_array_equal(result[2], expect_result["attention_mask"])
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
bert_out = tokenizer(input, return_offsets_mapping=True)
np.testing.assert_array_equal(result[3], bert_out['offset_mapping'])
print("\nTest sentence: " + str(input))
print("HF offset mapping: " + str(bert_out['offset_mapping']))
print("EXT offset mapping: ", end='')
for row in result[3]:
print("(" + str(row[0]) + ", " + str(row[1]) + "), ", end='')
print("\n")
class TestBertTokenizer(unittest.TestCase):
def test_text_to_case1(self):
print("\n\n****** Starting input ids, token type ids, and attention mask tests. ******\n")
_run_basic_case(
input="Input 'text' must not be empty.",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
@@ -53,23 +76,43 @@ class TestBertTokenizer(unittest.TestCase):
input="本想好好的伤感 想放任 但是没泪痕",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_case(
input="网 易 云 音 乐",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_case(
input="cat is playing toys",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_case(
input="cat isnot playing toyssss",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_combined_case(
["网 易 云 音 乐", "cat isnot playing toyssss"],
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
print("\n****** Input ids, token type ids, and attention mask tests complete. ******\n\n\n")
print("*** Starting offset mapping tests. ***\n")
_run_basic_with_offset_check(
input="网 易 云 音 乐",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="cat is playing toys",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="cat isnot playing toyssss",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="ah oui on peut parler francais",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="und eigentlich auch deutsch",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="podemos hablar muchos idiomas",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
_run_basic_with_offset_check(
input="",
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
)
print("\n*** Offset mapping tests complete. ***\n")
if __name__ == "__main__":
unittest.main()