From d1148aea4e278ec0e781d93e6b91570ab7edc7dd Mon Sep 17 00:00:00 2001
From: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Date: Sat, 4 Nov 2023 22:56:26 -0700
Subject: [PATCH] Support 'added_token' attribute for BPE tokenizer and some
 code refactoring. (#591)

* Fix CodeGenTokenizer issues and the related code refactoring.
* refactor the trie-tree
* temp check-ins
* code complete
* correctness fixing
* Update _hf_cvt.py
* more test cases fixing
* more refinement
* linux crash fixing
* Update test_autotokenizer.py
---
 base/ustring.h                         |   3 +-
 includes/onnxruntime_customop.hpp      |  23 +++--
 onnxruntime_extensions/_hf_cvt.py      |  60 +++++-------
 operators/tokenizer/bpe_decoder.hpp    |  57 ++++++-----
 operators/tokenizer/bpe_kernels.cc     |  88 +++++++++--------
 operators/tokenizer/bpe_kernels.h      |  53 ++++++-----
 operators/tokenizer/bpe_tokenizer.hpp  | 113 ++++++++++++++++------
 operators/tokenizer/bpe_utils.hpp      |  45 ++++++---
 operators/tokenizer/tokenizers.cc      |   8 +-
 operators/tokenizer/trie_tokenizer.hpp |  50 ++--------
 operators/tokenizer/trietree.hpp       | 126 +++++++++++++++++++++++++
 test/test_autotokenizer.py             |   8 +-
 12 files changed, 411 insertions(+), 223 deletions(-)
 create mode 100644 operators/tokenizer/trietree.hpp

diff --git a/base/ustring.h b/base/ustring.h
index a6d42ab4..58a80060 100644
--- a/base/ustring.h
+++ b/base/ustring.h
@@ -2,7 +2,6 @@
 // Licensed under the MIT License.
 #pragma once
-#include "ocos.h"
 #include
 #include
@@ -85,7 +84,7 @@ class ustring : public std::u32string {
   using u32string = std::u32string;
   static u32string FromUTF8(const std::string_view& utf8) {
     u32string ucs32;
-    ucs32.reserve(utf8.length() / 2); // a rough estimation for less memory allocation.
+    ucs32.reserve(utf8.length() / 2);  // a rough estimation for less memory allocation.
     for (size_t i = 0; i < utf8.size();) {
       char32_t codepoint = 0;
       if ((utf8[i] & 0x80) == 0) {
diff --git a/includes/onnxruntime_customop.hpp b/includes/onnxruntime_customop.hpp
index 48b68e12..e67876a8 100644
--- a/includes/onnxruntime_customop.hpp
+++ b/includes/onnxruntime_customop.hpp
@@ -30,7 +30,7 @@ class API {
  // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
 public:
  static API& instance(const OrtApi* ort_api = nullptr) noexcept {
-    static API self(*ort_api);
+    static API self(ort_api);
    return self;
  }

@@ -54,15 +54,15 @@ class API {
    return &api_;
  }

-  API(const OrtApi& api) : api_(api) {
-    if (&api == nullptr) {
+  API(const OrtApi* api) : api_(*api) {
+    if (api == nullptr) {
      ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
    }
  }
+  const OrtApi& api_;
};

-
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
  return instance()->KernelInfoGetAttribute_int64(&info, name, &value);

@@ -107,21 +107,32 @@ inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
  return API::CreateStatus(code, msg);
}

+inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
+  return API::CreateStatus(code, msg.c_str());
+}
+
inline void ReleaseStatus(OrtStatusPtr& status) {
  API::ReleaseStatus(status);
  status = nullptr;
}
-
}  // namespace OrtW

+#define ORTX_RETURN_IF_ERROR(expr) \
+  do {                             \
+    auto _status = (expr);         \
+    if (_status != nullptr) {      \
+      return _status;              \
+    }                              \
+  } while (0)
+
namespace Ort {
namespace Custom {

#ifdef USE_CUDA

///////////////////////////////////////////////////////////////////////////
// TODO: include the definition from the header file in ONNXRuntime
-struct CudaContext {};
+struct CudaContext {};

#endif  // USE_CUDA

diff --git a/onnxruntime_extensions/_hf_cvt.py b/onnxruntime_extensions/_hf_cvt.py
index a03302f3..b023ca48 100644
--- a/onnxruntime_extensions/_hf_cvt.py
+++ b/onnxruntime_extensions/_hf_cvt.py
@@ -23,36 +23,28 @@ class HFTokenizerConverter(CustomOpConverter):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

-    def bpe_tokenizer(self, **kwargs):
-        hf_gpt2_tokenizer = self.tokenizer
-        attrs = None
+    @staticmethod
+    def convert_bpe_vocab(hf_tokenizer):
+        attrs = {'vocab': json.dumps(
+            hf_tokenizer.encoder, separators=(',', ':'))}
+        if hf_tokenizer.added_tokens_encoder:
+            # ids = sorted(hf_tokenizer.added_tokens_encoder.values())
+            # if not ids == list(range(min(ids), max(ids) + 1)):
+            #     raise RuntimeError(f"{hf_tokenizer.__name__}: the ids in added_tokens_encoder are not consecutive")
+            token_map = [f"{_k}={_v}" for _k, _v in hf_tokenizer.added_tokens_encoder.items()]
+            attrs.update({"added_token": "\n".join(token_map)})

-        if type(self.tokenizer).__name__.endswith('Fast'):
-            raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
-        elif(self.tokenizer.name_or_path.endswith('gpt-4')):
-            # Fill vocab gap for GPT4Tokenizer to create continuous domain
-            vocab_dict = hf_gpt2_tokenizer.encoder
-            partial_values = list(vocab_dict.values())
-
-            max_vocab = partial_values[-1]
-            all_values = np.arange(max_vocab + 1)
-
-            missing_values = set(all_values) - set(partial_values)
-
-            for v in missing_values:
-                vocab_dict[str(uuid.uuid4())] = int(v)
-
-            vocab_dict = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))
-
-            attrs = {'vocab': json.dumps(
-                vocab_dict, separators=(',', ':'))}
-        else:
-            attrs = {'vocab': json.dumps(
-                hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
-
-        sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
+        sorted_merges = {v_: k_ for k_, v_ in hf_tokenizer.bpe_ranks.items()}
        attrs['merges'] = '\n'.join("{} {}".format(
            *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
+        return attrs
+
+    def bpe_tokenizer(self, **kwargs):
+        hf_gpt2_tokenizer = self.tokenizer
+        if type(self.tokenizer).__name__.endswith('Fast'):
+            raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
+
+        attrs = self.convert_bpe_vocab(hf_gpt2_tokenizer)
        attrs.update(**kwargs)
        return attrs

@@ -101,12 +93,7 @@ class HFTokenizerConverter(CustomOpConverter):
        if type(self.tokenizer).__name__.endswith('Fast'):
            raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')

-        attrs = {'vocab': json.dumps(
-            hf_clip_tokenizer.encoder, separators=(',', ':'))}
-        sorted_merges = {v_: k_ for k_,
-                         v_ in hf_clip_tokenizer.bpe_ranks.items()}
-        attrs['merges'] = '\n'.join("{} {}".format(
-            *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
+        attrs = self.convert_bpe_vocab(hf_clip_tokenizer)
        attrs.update(**kwargs)
        return attrs

@@ -116,12 +103,7 @@ class HFTokenizerConverter(CustomOpConverter):
        if type(self.tokenizer).__name__.endswith('Fast'):
            raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')

-        attrs = {'vocab': json.dumps(
-            hf_roberta_tokenizer.encoder, separators=(',', ':'))}
-        sorted_merges = {v_: k_ for k_,
-                         v_ in hf_roberta_tokenizer.bpe_ranks.items()}
-        attrs['merges'] = '\n'.join("{} {}".format(
-            *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
+        attrs = self.convert_bpe_vocab(hf_roberta_tokenizer)
        attrs.update(**kwargs)
        return attrs

diff --git a/operators/tokenizer/bpe_decoder.hpp b/operators/tokenizer/bpe_decoder.hpp
index 5c0aa80d..300e83a2 100644
--- a/operators/tokenizer/bpe_decoder.hpp
+++ b/operators/tokenizer/bpe_decoder.hpp
@@ -16,19 +16,27 @@
 #include
 #include
-
-struct KernelBpeDecoder : public BaseKernel {
+struct KernelBpeDecoder {
 public:
-  KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
-    std::string vocab = ort_.KernelInfoGetAttribute(&info, "id_vocab");
-    if (vocab.empty()) {
-      ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+    // note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
+    std::string vocab;
+    OrtStatusPtr status = OrtW::GetOpAttribute(info, "id_vocab", vocab);
+    if (status != nullptr || vocab.empty()) {
+      if (status == nullptr) {
+        status = OrtW::CreateStatus("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
+      }
+      return status;
    }
    BuildIdVocab(vocab);

-    std::string byte_decoder = ort_.KernelInfoGetAttribute(&info, "byte_decoder");
-    if (byte_decoder.empty()) {
-      ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
+    std::string byte_decoder;
+    status = OrtW::GetOpAttribute(info, "byte_decoder", byte_decoder);
+    if (status != nullptr || byte_decoder.empty()) {
+      if (status == nullptr) {
+        status = OrtW::CreateStatus("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
+      }
+      return status;
    } else {
      auto um = ParseId2String(byte_decoder);
      std::transform(um.begin(), um.end(),
@@ -37,13 +45,15 @@ struct KernelBpeDecoder : public BaseKernel {
                     ort_extensions::narrow(std::stoul(p.second)));
                     });
    }
-    std::string added_tokens = TryToGetAttributeWithDefault("added_tokens", "");
+    std::string added_tokens;
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
    if (!added_tokens.empty()) {
      auto um = ParseId2String(added_tokens);
      added_tokens_ = std::map(um.begin(), um.end());
    }

-    std::string all_special_ids = TryToGetAttributeWithDefault("all_special_ids", "");
+    std::string all_special_ids;
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
    if (!all_special_ids.empty()) {
      auto um = ParseId2String(all_special_ids);
      std::transform(um.begin(), um.end(),
@@ -51,12 +61,14 @@
                     [](const auto& p) { return p.first; });
    }

-    en_normalization_ = TryToGetAttributeWithDefault("en_normalization", 0);
-    skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", 0);
-    whitespace_token_ = TryToGetAttributeWithDefault("whitespace_token", 0);
-    bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
-    eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
-    unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));
+
+    return status;
  }

  std::unordered_map ParseId2String(const std::string& s_attr) {
@@ -102,8 +114,8 @@
    arr_vocab_.shrink_to_fit();
  }

-  void Compute(const ortc::Tensor& ids,
-               ortc::Tensor& output) const {
+  OrtStatusPtr Compute(const ortc::Tensor& ids,
+                       ortc::Tensor& output) const {
    const int64_t* p_ids = ids.Data();
    const auto& ids_dim = ids.Shape();
    std::vector output_dim = {1};
@@ -168,12 +180,13 @@
      p_ids += seq_len;
    }
    output.SetStringOutput(decoded_strings, output_dim);
+    return nullptr;
  }

 private:
-  std::string bos_token_;
-  std::string eos_token_;
-  std::string unk_token_;
+  std::string bos_token_{"<|endoftext|>"};
+  std::string eos_token_{"<|endoftext|>"};
+  std::string unk_token_{"<|endoftext|>"};

  // Since ORT API doesn't support boolean type in ONNX node attribute,
  // all flag attributes here are defined as int64 type to be more explicit.
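For reference, the new `added_token` attribute produced by `HFTokenizerConverter.convert_bpe_vocab` above is a newline-separated list of `token=id` pairs, and the BPE model loader later in this patch (`BpeModel::LoadAddedTokens`) parses it by splitting each line at the last '='. The following is a minimal round-trip sketch of that wire format; the helper names and example token ids are illustrative only and are not part of the patch:

# Sketch of the 'added_token' attribute format (assumed helpers, not repo code).
def serialize_added_tokens(added_tokens_encoder: dict) -> str:
    # mirrors convert_bpe_vocab: one "token=id" pair per line
    return "\n".join(f"{tok}={idx}" for tok, idx in added_tokens_encoder.items())

def parse_added_tokens(attr_value: str) -> dict:
    # mirrors BpeModel::LoadAddedTokens: split each line at the *last* '='
    # so tokens that themselves contain '=' still parse correctly
    result = {}
    for line in attr_value.splitlines():
        line = line.rstrip("\r")
        if not line:
            continue
        token, _, idx = line.rpartition("=")
        result[token] = int(idx)
    return result

pairs = {"<|im_start|>": 100264, "<|im_end|>": 100265}  # example values only
assert parse_added_tokens(serialize_added_tokens(pairs)) == pairs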
diff --git a/operators/tokenizer/bpe_kernels.cc b/operators/tokenizer/bpe_kernels.cc
index 4b2f1b82..06dd1e2a 100644
--- a/operators/tokenizer/bpe_kernels.cc
+++ b/operators/tokenizer/bpe_kernels.cc
@@ -6,6 +6,8 @@
 #include

+using namespace ort_extensions;
+
 std::string BpeModelConf::GetSpecialTokens() const {
   std::string special_tokens = unk_token_;  // unk_token_ is required
   auto add_token = [](std::string& sp, const char* tok) {
@@ -87,43 +89,53 @@ ustring RemoveConsecutiveSpaces(const ustring& input) {
   return result;
 }

-KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf)
-    : BaseKernel(api, info),
-      bpe_conf_(conf) {
-  std::string vocab = ort_.KernelInfoGetAttribute(&info, "vocab");
+KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
+    : bpe_conf_(conf){};
+
+OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+  // note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
+  std::string vocab;
+  ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
   if (vocab.empty()) {
-    ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
+    return OrtW::CreateStatus("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
   }

-  std::string merges = ort_.KernelInfoGetAttribute(&info, "merges");
+  std::string merges;
+  ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
   if (merges.empty()) {
-    ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
-  }
-
-  if (!TryToGetAttribute("padding_length", padding_length_)) {
-    padding_length_ = -1;
+    return OrtW::CreateStatus("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
   }

+  ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
   if (padding_length_ != -1 && padding_length_ <= 0) {
-    ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
+    return OrtW::CreateStatus("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
   }

   std::stringstream vocabu_stream(vocab);
   std::stringstream merges_stream(merges);
   bbpe_tokenizer_ = std::make_unique();
-  bbpe_tokenizer_->Load(vocabu_stream, merges_stream, conf.unk_token_, conf.GetSpecialTokens().c_str());
+  auto status = bbpe_tokenizer_->Load(vocabu_stream, merges_stream, bpe_conf_.unk_token_, bpe_conf_.GetSpecialTokens().c_str());
+  if (status != nullptr) {
+    return status;
+  }
+
+  std::string added_token;
+  ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
+  ORTX_RETURN_IF_ERROR(bbpe_tokenizer_->LoadAddedTokens(added_token.c_str()));

   // TODO: need to check if the special token ids are the same as the ones in HFTokenizer
-  unk_token_id_ = bbpe_tokenizer_->GetTokenId(conf.unk_token_);
-  if (conf.bos_token_ != nullptr) {
-    bos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.bos_token_);
+  unk_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.unk_token_);
+  if (bpe_conf_.bos_token_ != nullptr) {
+    bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.bos_token_);
   }
-  if (conf.eos_token_ != nullptr) {
-    eos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.eos_token_);
+  if (bpe_conf_.eos_token_ != nullptr) {
+    eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.eos_token_);
   }
-  if (conf.pad_token_ != nullptr) {
-    pad_token_id_ = bbpe_tokenizer_->GetTokenId(conf.pad_token_);
+  if (bpe_conf_.pad_token_ != nullptr) {
+    pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.pad_token_);
   }
+
+  return nullptr;
 }

 std::vector KernelBpeTokenizer::Tokenize(ustring& input,
@@ -171,21 +183,20 @@ std::vector KernelBpeTokenizer::Tokenize(ustring& input,
   }

   // Parse input
-  auto special_token_split_res = bbpe_tokenizer_->SplitBySpecialTokens(input);
-  TokenWithRegularExp regcmp;
+  auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
+  bpe::TokenWithRegularExp regcmp;

   for (auto& seg_id : special_token_split_res) {
     if (static_cast(res.size()) >= max_length) break;

-    if (seg_id.second != -1) {
+    if (seg_id.second != bpe::kInvalidTokenId) {
       res.push_back(seg_id.second);
       continue;
     }

-    auto cur_input = std::move(seg_id.first);
     // Note: keep ptr to make sure the string_view is valid in the following process
-    const char32_t* ptr = cur_input.c_str();
-    regcmp.Set(ptr);
+    std::u32string str(seg_id.first);
+    regcmp.Set(str.c_str());

     size_t offset = 0;
     OffsetMappingType offset_mapping;
@@ -199,7 +210,7 @@ std::vector KernelBpeTokenizer::Tokenize(ustring& input,
       while (static_cast(res.size()) < max_length) {
         auto [b, tok] = regcmp.GetNextToken();
-
+
         if (!b) break;

         std::string utf8_token = std::string(ustring(tok));
@@ -271,13 +282,14 @@ std::vector KernelBpeTokenizer::Tokenize(ustring& input,
     // Add EOS token to result
     res.push_back(eos_token_id_);
   }
+
   return res;
 }

-void KernelBpeTokenizer::Compute(const ortc::Tensor& input,
-                                 ortc::Tensor& tokenize_output,
-                                 std::optional*> attention_mask,
-                                 std::optional*> offset_mapping) const {
+OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor& input,
+                                         ortc::Tensor& tokenize_output,
+                                         std::optional*> attention_mask,
+                                         std::optional*> offset_mapping) const {
   // Setup inputs
   std::vector str_input{input.Data()};
   std::list offset_map;
@@ -356,11 +368,13 @@ void KernelBpeTokenizer::Compute(const ortc::Tensor& input,
       idx++;
     }
   }
+
+  return nullptr;
 }

 static const auto kGPT2Confinguration = BpeModelConf();
-GPT2Tokenizer::GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info)
-    : KernelBpeTokenizer(api, info, kGPT2Confinguration) {}
+GPT2Tokenizer::GPT2Tokenizer()
+    : KernelBpeTokenizer(kGPT2Confinguration) {}

 static const auto kRobertaConfiguration = BpeModelConf{
     BpeModelConf::kModel_Roberta,  // name
     "",                            // unk_token
     "",                            // bos_token
     "",                            // eos_token
     ""};                           // pad_token

-RobertaTokenizer::RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info)
-    : KernelBpeTokenizer(api, info, kRobertaConfiguration) {}
+RobertaTokenizer::RobertaTokenizer()
+    : KernelBpeTokenizer(kRobertaConfiguration) {}

 static const auto kCLIPConfiguration = BpeModelConf{
     BpeModelConf::kModel_CLIP,  // name
     "<|endoftext|>",            // unk_token
     "<|startoftext|>",          // bos_token
     "<|endoftext|>",            // eos_token
     "<|endoftext|>"};           // pad_token

-CLIPTokenizer::CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info)
-    : KernelBpeTokenizer(api, info, kCLIPConfiguration) {}
+CLIPTokenizer::CLIPTokenizer()
+    : KernelBpeTokenizer(kCLIPConfiguration) {}
diff --git a/operators/tokenizer/bpe_kernels.h b/operators/tokenizer/bpe_kernels.h
index cbfc160e..94c2f4c6 100644
--- a/operators/tokenizer/bpe_kernels.h
+++ b/operators/tokenizer/bpe_kernels.h
@@ -23,15 +23,18 @@ struct BpeModelConf {
   std::string GetSpecialTokens() const;
 };

+namespace ort_extensions {
 class BpeModel;
+}

-struct KernelBpeTokenizer : BaseKernel {
-  KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf);
+struct KernelBpeTokenizer {
+  KernelBpeTokenizer(const BpeModelConf& conf);
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);

-  void Compute(const ortc::Tensor& input,
-               ortc::Tensor& tokenize_output,
-               std::optional*> attention_mask,
-               std::optional*> offset_mapping) const;
+  OrtStatusPtr Compute(const ortc::Tensor& input,
+                       ortc::Tensor& tokenize_output,
+                       std::optional*> attention_mask,
+                       std::optional*> offset_mapping) const;

   const char* ModelName() const { return bpe_conf_.name_; }

@@ -41,10 +44,12 @@ struct KernelBpeTokenizer : BaseKernel {
                                 int64_t max_length,
                                 bool compute_offset_mapping,
                                 std::list& offset_map) const;
-  int64_t padding_length_;
-  std::unique_ptr bbpe_tokenizer_;
-  const BpeModelConf& bpe_conf_;

+ private:
+  const BpeModelConf& bpe_conf_;
+  std::unique_ptr bbpe_tokenizer_;
+
+  int64_t padding_length_ = -1;
   uint32_t unk_token_id_{};
   uint32_t bos_token_id_{};
   uint32_t eos_token_id_{};
@@ -52,34 +57,34 @@
 };

 struct GPT2Tokenizer : KernelBpeTokenizer {
-  GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info);
+  GPT2Tokenizer();
   // required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
-  void Compute(const ortc::Tensor& input,
-               ortc::Tensor& tokenize_output,
-               std::optional*> attention_mask,
-               std::optional*> offset_mapping) const {
+  OrtStatusPtr Compute(const ortc::Tensor& input,
+                       ortc::Tensor& tokenize_output,
+                       std::optional*> attention_mask,
+                       std::optional*> offset_mapping) const {
     return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
   }
 };

 struct RobertaTokenizer : KernelBpeTokenizer {
-  RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info);
+  RobertaTokenizer();
   // required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
-  void Compute(const ortc::Tensor& input,
-               ortc::Tensor& tokenize_output,
-               std::optional*> attention_mask,
-               std::optional*> offset_mapping) const {
+  OrtStatusPtr Compute(const ortc::Tensor& input,
+                       ortc::Tensor& tokenize_output,
+                       std::optional*> attention_mask,
+                       std::optional*> offset_mapping) const {
     return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
   }
 };

 struct CLIPTokenizer : KernelBpeTokenizer {
-  CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info);
+  CLIPTokenizer();
   // required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
-  void Compute(const ortc::Tensor& input,
-               ortc::Tensor& tokenize_output,
-               std::optional*> attention_mask,
-               std::optional*> offset_mapping) const {
+  OrtStatusPtr Compute(const ortc::Tensor& input,
+                       ortc::Tensor& tokenize_output,
+                       std::optional*> attention_mask,
+                       std::optional*> offset_mapping) const {
     return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
   }
 };
diff --git a/operators/tokenizer/bpe_tokenizer.hpp b/operators/tokenizer/bpe_tokenizer.hpp
index 6d0bbf2e..c6b3db28 100644
--- a/operators/tokenizer/bpe_tokenizer.hpp
+++ b/operators/tokenizer/bpe_tokenizer.hpp
@@ -12,18 +12,23 @@
 #include
 #include
 #include
+#include
+#include

 #include "nlohmann/json.hpp"
 #include "bpe_utils.hpp"
+#include "trietree.hpp"
+
+namespace ort_extensions {

 class BpeModel {
  public:
   BpeModel() = default;

-  void Load(std::istream& vocab_stream,
-            std::istream& merges_stream,
-            const char* unk_token,
-            const char* special_tokens) {
+  OrtStatusPtr Load(std::istream& vocab_stream,
+                    std::istream& merges_stream,
+                    const char* unk_token,
+                    const char* special_tokens) {
     nlohmann::json tok_json;
     vocab_stream >> tok_json;
     vocab_map_ = std::move(tok_json.get>());
@@ -34,6 +39,7 @@ class BpeModel {
     } else {
       auto id = ort_extensions::narrow(vocab_map_.size());
       vocab_map_[unk_token] = id;
+      unk_id_ = id;
     }

     CreateByteEncoder();
@@ -46,7 +52,7 @@ class BpeModel {
       if ((line[0] == '#') && (index == 0)) continue;
       auto pos = line.find(' ');
       if (pos == std::string::npos) {
-        ORTX_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
+        return OrtW::CreateStatus("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
       }
       std::string w1 = line.substr(0, pos);
       std::string w2 = line.substr(pos + 1);
@@ -54,9 +60,9 @@ class BpeModel {
       if (w2.find("") != std::string::npos || w1.find("") != std::string::npos) {
         token_length -= 4;
       }
-      auto iw1 = GetVocabIndex(w1);
-      auto iw2 = GetVocabIndex(w2);
-      auto iww = GetVocabIndex(w1 + w2);
+      auto iw1 = GetTokenId(w1);
+      auto iw2 = GetTokenId(w2);
+      auto iww = GetTokenId(w1 + w2);
       BpeNode value{iww, index++, token_length};
       bpe_rank_[GetRankKey(iw1, iw2)] = value;
     }
@@ -80,8 +86,62 @@ class BpeModel {
     id2token_map_.resize(vocab_map_.size());
     for (const auto& [t, i] : vocab_map_) {
+      if (i > static_cast(std::numeric_limits::max())) {
+        continue;  // safe purpose.
+      }
+      if (i > id2token_map_.size()) {
+        id2token_map_.resize(i + 1);
+      }
       id2token_map_[i] = t;
     }
+
+    return nullptr;
+  }
+
+  OrtStatusPtr LoadAddedTokens(const char* added_tokens) {
+    int id = bpe::kInvalidTokenId;
+    std::istringstream strm_tokens(added_tokens);
+    std::string line;
+    while (!strm_tokens.eof()) {
+      std::getline(strm_tokens, line);
+      line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
+      if (line.empty()) continue;
+      // separate the key and value by =
+      auto pos = line.rfind("=");
+      if (pos == std::string::npos) {
+        return OrtW::CreateStatus("Error on parse a added_token line: " + line, ORT_INVALID_ARGUMENT);
+      }
+      auto token = line.substr(0, pos);
+      auto id_str = line.substr(pos + 1);  // 1 is the length of "="
+      auto [ptr, ec] = std::from_chars(id_str.data(), id_str.data() + id_str.length(), id);
+      if (ec != std::errc()) {
+        return OrtW::CreateStatus("Cannot convert to an integer from " + id_str, ORT_INVALID_ARGUMENT);
+      }
+
+      added_tokens_.Add(ustring(token), 0, std::make_optional(id));
+    }
+
+    return nullptr;
+  }
+
+  // REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
+  bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
+    // split by added tokens
+    bpe::TokenPairs added_result;
+    bpe::TokenPairs final_result;
+    added_tokens_.Split(input, added_result);
+    for (const auto& [token, id] : added_result) {
+      if (id != bpe::kInvalidTokenId) {
+        final_result.emplace_back(token, id);
+      } else {
+        auto special_result = special_tokens_.SplitBySpecialTokens(token);
+        for (const auto& [token, id] : special_result) {
+          final_result.emplace_back(token, id);
+        }
+      }
+    }
+
+    return final_result;
   }

   void bpe(std::list>& vals) const {
@@ -94,9 +154,15 @@
     for (auto it = vals.begin(); it != vals.end(); ++it) {
       auto it2 = it;
       ++it2;
-      if (it2 == vals.end()) break;
+      if (it2 == vals.end()) {
+        break;
+      }
+
       auto map_it = bpe_rank_.find(GetRankKey(it->first, it2->first));
-      if (map_it == bpe_rank_.end()) continue;
+      if (map_it == bpe_rank_.end()) {
+        continue;
+      }
+
       if (minval > map_it->second.value) {
         ori_id1 = it->first;
         ori_id2 = it2->first;
@@ -105,7 +171,10 @@
         aim_id = map_it->second.id;
       }
     }
-    if (pos_it == vals.end()) break;
+
+    if (pos_it == vals.end()) {
+      break;
+    }

     token_length = pos_it->second;
     pos_it = vals.erase(pos_it);
@@ -129,11 +198,6 @@
     return byte_encoder_;
   }

-  auto SplitBySpecialTokens(const ustring& input) const {
-    return special_tokens_.SplitBySpecialTokens(input);
-  }
-
-  // Returns token if key was found in vocab, and unk_id_ otherwise
   uint32_t GetTokenId(const std::string& key) {
     auto it = vocab_map_.find(key);
     if (it != end(vocab_map_)) {
@@ -163,21 +227,13 @@
       )
     */
     if ((i >= 0 && i < 33) || (i >= 127 && i < 161) || (i == 173)) {
-      byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(index++));
+      byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(index++));
     } else {
-      byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(i));
+      byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(i));
     }
   }
 }

-  uint32_t GetVocabIndex(const std::string& str) {
-    auto it = vocab_map_.find(str);
-    if (it == vocab_map_.end()) {
-      ORTX_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
-    }
-    return it->second;
-  }
-
  private:
   std::map bpe_rank_;

@@ -186,5 +242,8 @@
   std::vector id2token_map_;
   uint32_t unk_id_ = std::numeric_limits::max();
-  SpecialTokenMap special_tokens_;
+  bpe::SpecialTokenMap special_tokens_;
+  TrieTree added_tokens_;
 };
+
+}  // namespace ort_extensions
diff --git a/operators/tokenizer/bpe_utils.hpp b/operators/tokenizer/bpe_utils.hpp
index 217cdd41..1d35b86e 100644
--- a/operators/tokenizer/bpe_utils.hpp
+++ b/operators/tokenizer/bpe_utils.hpp
@@ -6,35 +6,43 @@
 #include "ocos.h"
 #include "narrow.h"

+#include
 #include

 #include "ustring.h"
 #include "unicode.h"

+namespace ort_extensions {
+namespace bpe {
+
+using TokenPairs = std::vector>;
+using u32string_view = std::u32string_view;
+
+constexpr int kInvalidTokenId = -1;
+
 class SpecialTokenMap {
  public:
   void Add(ustring p_str, int p_id) {
     auto it = token_map_.find(p_str);
     if (it != token_map_.end()) {
-      if (it->second != p_id) {
-        ORTX_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
-      }
+      assert(it->second == p_id && "Duplicate special tokens.");
     } else {
       token_map_[p_str] = p_id;
       token_list_.push_back(SpecialTokenInfo(std::move(p_str), p_id));
     }
   }

-  std::vector> SplitBySpecialTokens(ustring input) const {
-    std::vector> res;
-    res.emplace_back(std::move(input), -1);
+  TokenPairs SplitBySpecialTokens(const std::u32string_view& input) const {
+    TokenPairs res;
+    res.emplace_back(input, kInvalidTokenId);
     for (const auto& st : token_list_) {
-      std::vector> new_split_res;
+      TokenPairs new_split_res;
       for (auto& str : res) {
-        if (str.second != -1) {
-          new_split_res.push_back(std::move(str));
+        if (str.second != kInvalidTokenId) {
+          new_split_res.emplace_back(str);
           continue;
         }
+
         auto it = str.first.begin();
         size_t search_pos = 0;
         while (it != str.first.end()) {
@@ -46,21 +54,27 @@
               std::boyer_moore_searcher(st.str.begin(), st.str.end()));
 #endif
           if (search_it == str.first.end()) {
-            new_split_res.emplace_back(str.first.substr(search_pos), -1);
+            new_split_res.emplace_back(u32string_view(
+                                           str.first.data() + search_pos, str.first.size() - search_pos),
+                                       kInvalidTokenId);
             break;
           }
+
           auto prefixLen = search_it - it;
           if (prefixLen != 0) {
-            new_split_res.emplace_back(str.first.substr(search_pos, prefixLen), -1);
+            new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, prefixLen), kInvalidTokenId);
             search_pos += prefixLen;
           }
-          new_split_res.emplace_back(str.first.substr(search_pos, st.str.size()), st.id);
+
+          new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, st.str.size()), st.id);
           it = search_it + st.str.size();
           search_pos += st.str.size();
         }
       }
+
       std::swap(new_split_res, res);
     }
+
     return res;
   }

@@ -101,7 +115,6 @@ class TokenWithRegularExp {

 private:
  std::u32string_view TryMatch() {
-
   // python pattern:
   // 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+

@@ -198,8 +211,7 @@ class TokenWithRegularExp {
     for (; i < m_text.size(); ++i) {
       if (!IsZ(m_text[i])) break;
     }
-    if ((i > 1) && (i != m_text.size()))  //\s+(?!\S)
-    {
+    if ((i > 1) && (i != m_text.size())) {  //\s+(?!\S)
       i--;
       std::u32string_view res = m_text.substr(0, i);
       m_text = m_text.substr(i);
@@ -230,3 +242,6 @@ class TokenWithRegularExp {
 private:
  std::u32string_view m_text;
 };
+
+}  // namespace bpe
+}  // namespace ort_extensions
diff --git a/operators/tokenizer/tokenizers.cc b/operators/tokenizer/tokenizers.cc
index 66d647a3..aaf23b87 100644
--- a/operators/tokenizer/tokenizers.cc
+++ b/operators/tokenizer/tokenizers.cc
@@ -35,10 +35,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
   static OrtOpLoader op_loader(
 #ifdef ENABLE_GPT2_TOKENIZER
-      CustomCpuStruct("GPT2Tokenizer", GPT2Tokenizer),
-      CustomCpuStruct("CLIPTokenizer", CLIPTokenizer),
-      CustomCpuStruct("RobertaTokenizer", RobertaTokenizer),
-      CustomCpuStruct("BpeDecoder", KernelBpeDecoder),
+      CustomCpuStructV2("GPT2Tokenizer", GPT2Tokenizer),
+      CustomCpuStructV2("CLIPTokenizer", CLIPTokenizer),
+      CustomCpuStructV2("RobertaTokenizer", RobertaTokenizer),
+      CustomCpuStructV2("BpeDecoder", KernelBpeDecoder),
 #endif

 #ifdef ENABLE_SPM_TOKENIZER
diff --git a/operators/tokenizer/trie_tokenizer.hpp b/operators/tokenizer/trie_tokenizer.hpp
index 878c3906..79fb3b48 100644
--- a/operators/tokenizer/trie_tokenizer.hpp
+++ b/operators/tokenizer/trie_tokenizer.hpp
@@ -14,66 +14,32 @@
 #include

 #include "unescape.h"
+#include "trietree.hpp"

 // This Trie Tree is C++ implementation of
 // https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/rwkv_tokenizer.py
 // Perf optimized by leveraging C++ features, but the algorithm is the same.
-class TrieTree {
+class RWKVTrieTree : public ort_extensions::TrieTree {
  public:
   static constexpr int kMaxTokenLength_ = 128;

-  TrieTree(unsigned char ch = 0) : ch_(ch), to_(256) {}
+  RWKVTrieTree(char ch = 0) : TrieTree(ch) {}

+  // keep the same function for source code understanding.
   void add(const std::string& key, int idx = 0,
            std::optional value = std::optional()) {
-    if (idx == key.length()) {
-      if (!value) {
-        value = key[0];
-      }
-      value_ = value;
-      return;
-    }
-
-    unsigned char ch = static_cast(key[idx]);
-    if (to_[ch] == nullptr) {
-      to_[ch] = std::make_unique(ch);
-    }
-    to_[ch]->add(key, idx + 1, value);
+    Add(key, idx, value);
   }

   int find_longest(const std::string& key, size_t& idx) {
-    const TrieTree* u = this;
-    unsigned char ch = key[idx];
-
-    int tok_id = 0;
-    size_t idx_end = idx;
-    while (u->to_[ch]) {
-      u = u->to_[ch].get();
-      idx += 1;
-      if (u->value_) {
-        tok_id = *u->value_;
-        idx_end = idx;
-      }
-      if (idx == key.length()) {
-        break;
-      }
-      ch = key[idx];
-    }
-
-    idx = idx_end;
-    return tok_id;
+    return FindLongest(key, idx);
   }
-
- private:
-  std::vector> to_;
-  std::optional value_;
-  unsigned char ch_;
 };

 class TrieTokenizer {
  private:
   std::map idx2token;
-  TrieTree root;
+  RWKVTrieTree root;

  public:
   TrieTokenizer(const std::string& text_tokens) {
@@ -210,7 +176,7 @@ struct KernelTrieDetokenizer : public BaseKernel {
       if (ustring::ValidateUTF8(raw_string)) {
         output[n] = raw_string;
       } else {
-        output[n] = "\ufffd"; // bad utf-8 string
+        output[n] = "\ufffd";  // bad utf-8 string
         failed = true;
       }
     }
diff --git a/operators/tokenizer/trietree.hpp b/operators/tokenizer/trietree.hpp
new file mode 100644
index 00000000..85b1e722
--- /dev/null
+++ b/operators/tokenizer/trietree.hpp
@@ -0,0 +1,126 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include "ocos.h"
+#include "narrow.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace ort_extensions {
+
+template
+class TrieTree {
+ public:
+  static constexpr int kMaxTokenLength_ = 128;
+
+  TrieTree(CharT ch = 0, ValueT invalid_id = -1) : ch_(ch), invalid_id_(invalid_id) {}
+
+  void Add(const std::basic_string& key, int idx = 0,
+           const std::optional& value = std::nullopt) noexcept {
+    if (idx == key.length()) {
+      if (!value) {
+        value_ = std::make_optional(narrow(key[0]));
+      } else {
+        value_ = value;
+      }
+    } else {
+      auto ch = key[idx];
+      if (to_.count(ch) == 0) {
+        to_[ch] = std::make_unique(ch);
+      }
+      to_[ch]->Add(key, idx + 1, value);
+    }
+  }
+
+  ValueT FindLongest(const std::basic_string& key, size_t& idx) const noexcept {
+    const TrieTree* u = this;
+    CharT ch = key[idx];
+
+    ValueT tok_id = invalid_id_;
+    size_t idx_end = idx;
+    while (u->to_.count(ch)) {
+      u = u->to_.at(ch).get();
+      idx += 1;
+      if (u->value_) {
+        tok_id = *u->value_;
+        idx_end = idx;
+      }
+      if (idx == key.length()) {
+        break;
+      }
+      ch = key[idx];
+    }
+
+    idx = idx_end;
+    return tok_id;
+  }
+
+  int Split(const std::basic_string& input,
+            std::vector, ValueT>>& tokens) const noexcept {
+    size_t seg_idx = 0;
+    size_t tok_idx = 0;
+
+    while (tok_idx < input.length()) {
+      // variable u is the tree root.
+      const TrieTree* u = this;
+      auto ch = input[tok_idx];
+      size_t tok_len = 0;
+      size_t idx_end = tok_idx;
+      ValueT tok_id = invalid_id_;
+
+      // try to match a longest token
+      while (u->to_.count(ch)) {
+        tok_len += 1;
+        u = u->to_.at(ch).get();
+        if (u->value_) {
+          tok_id = *u->value_;
+          idx_end = tok_idx + 1;
+        }
+
+        tok_idx += 1;
+        if (tok_idx == input.length()) {
+          break;
+        }
+        ch = input[tok_idx];
+      }
+
+      tok_idx += 1;
+      if (tok_id == invalid_id_) {
+        if (tok_idx < input.length()) {
+          continue;
+        } else {
+          tok_idx += 1;  // Assign tok_idx to input.length()
+          idx_end = tok_idx;
+        }
+      }
+
+      auto token_begin_idx = tok_idx - tok_len - 1;  // since the tok_idx already moved forward by 1
+      tok_len = idx_end - token_begin_idx;
+      if (token_begin_idx > seg_idx || tok_len == 0) {
+        tokens.emplace_back(std::basic_string_view(input.data() + seg_idx, token_begin_idx - seg_idx),
+                            invalid_id_);
+      }
+      if (tok_id != invalid_id_) {
+        tokens.emplace_back(std::basic_string_view(input.data() + token_begin_idx, tok_len), tok_id);
+        tok_idx = idx_end;
+      }
+
+      // reset state for next match
+      seg_idx = tok_idx;
+    }
+
+    return 0;
+  }
+
+ private:
+  std::map> to_;
+  std::optional value_;
+  const CharT ch_;
+  const ValueT invalid_id_;
+};
+
+}  // namespace ort_extensions
diff --git a/test/test_autotokenizer.py b/test/test_autotokenizer.py
index 8c13c72b..eda4fb73 100644
--- a/test/test_autotokenizer.py
+++ b/test/test_autotokenizer.py
@@ -1,7 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import unittest
-import pkg_resources
 import numpy as np

 from transformers import AutoTokenizer, GPT2Tokenizer
@@ -76,7 +75,7 @@ class TestAutoTokenizer(unittest.TestCase):
             pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids, actual_ids)
-
+
     def test_gpt2_tokenizer(self):
         tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
         text = "Testing words with apostrophes such as you're, i'm, don't, etc."
@@ -96,7 +95,7 @@ class TestAutoTokenizer(unittest.TestCase):
                 " add words that should not exist and be tokenized to , such as saoneuhaoesuth")

         ids = tokenizer.encode(text, return_tensors="np")
-        ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})
+        ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
         actual_ids, *_ = ort_inference(ort_tok, [text])
         np.testing.assert_array_equal(ids[0], actual_ids)

@@ -124,8 +123,7 @@ class TestAutoTokenizer(unittest.TestCase):
         ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
         actual_ids, *_ = ort_inference(ort_tok, [code])
         self.assertEqual(len(ids['input_ids'].shape), len(actual_ids.shape))
-        # TODO: not matched.
-        # np.testing.assert_array_equal(ids['input_ids'], actual_ids)
+        np.testing.assert_array_equal(ids['input_ids'], actual_ids)


 if __name__ == '__main__':
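The updated tests above show the intended end-to-end flow for the exported tokenizer. A standalone usage sketch following the same pattern (the import path for ort_inference is assumed to mirror what test_autotokenizer.py uses, and the checkpoint name is only an example):

# Usage sketch, not part of the patch: export a slow HF BPE tokenizer and
# compare its output with the ONNX custom-op pipeline.
import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, ort_inference  # assumed import path

# use_fast=False is required; the converter rejects the *Fast tokenizers.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
text = "Testing words with apostrophes such as you're, i'm, don't, etc."

expected = tokenizer.encode(text, return_tensors="np")
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(expected[0], actual)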