diff --git a/operators/tokenizer/bpe_tokenizer_model.hpp b/operators/tokenizer/bpe_tokenizer_model.hpp
index cd6f03f7..e4a1e7c8 100644
--- a/operators/tokenizer/bpe_tokenizer_model.hpp
+++ b/operators/tokenizer/bpe_tokenizer_model.hpp
@@ -145,23 +145,32 @@ class BpeModel {
     uint32_t index = 0;
     auto merge_item = merges_json.begin();
     while (merge_item != merges_json.end()) {
-      std::string line = merge_item.value();
-      line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
-      if (line.empty()) continue;
-      if ((line[0] == '#') && (index == 0)) continue;
-      auto pos = line.find(' ');
-      if (pos == std::string::npos) {
-        return {
-            kOrtxErrorCorruptData,
-            "Cannot know how to parse line: " + line,
-        };
+      std::string w1, w2;
+      if (merge_item->is_string()) {
+        std::string line = merge_item.value();
+        line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
+        if (line.empty()) continue;
+        if ((line[0] == '#') && (index == 0)) continue;
+        auto pos = line.find(' ');
+        if (pos == std::string::npos) {
+          return {
+              kOrtxErrorCorruptData,
+              "Cannot know how to parse line: " + line,
+          };
+        }
+        w1 = line.substr(0, pos);
+        w2 = line.substr(pos + 1);
+      } else if (merge_item->is_array()) {
+        w1 = merge_item->at(0).get<std::string>();
+        w2 = merge_item->at(1).get<std::string>();
+      } else {
+        return {kOrtxErrorCorruptData, "Cannot know how to parse line: " + merge_item->dump()};
       }
-      std::string w1 = line.substr(0, pos);
-      std::string w2 = line.substr(pos + 1);
       int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
       if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
         token_length -= 4;
       }
+
       auto iw1 = GetTokenId(w1);
       auto iw2 = GetTokenId(w2);
       auto iww = GetTokenId(w1 + w2);
diff --git a/operators/tokenizer/ugm_kernels.hpp b/operators/tokenizer/ugm_kernels.hpp
index dc19f37b..703c683d 100644
--- a/operators/tokenizer/ugm_kernels.hpp
+++ b/operators/tokenizer/ugm_kernels.hpp
@@ -18,6 +18,7 @@
 #include "op_def_struct.h"
 #include "base64.h"
 #include "ustring.h"
+#include "narrow.h"
 #include "nlohmann/json.hpp"
 #include "trietree.hpp"
 #include "tokenizer_jsconfig.hpp"
@@ -485,7 +486,7 @@ class SpmUgmDecoder {
       std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
     }

-    size_t seq_len = ids_dim.back();
+    int64_t seq_len = ids_dim.back();
     size_t string_batch = ids.NumberOfElement() / seq_len;
     std::vector<std::string> decoded_strings;

@@ -495,7 +496,7 @@
       std::string text;
       for (int64_t i = 0; i < seq_len; ++i) {
         std::string token;
-        Id2Token(p_ids[i], token, nullptr);
+        Id2Token(ort_extensions::narrow<extTokenId_t>(p_ids[i]), token, nullptr);
         if (token.find(spm_escaped_space) == 0) {
           token = ws + token.substr(spm_escaped_space.length());
         }
@@ -508,6 +509,7 @@
           text = text.substr(1);
         }
       }
+
      decoded_strings.push_back(text);
    }

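Note (editor's sketch, not part of the diff): the new is_string()/is_array() branch in BpeModel handles the two encodings of the BPE "merges" list found in Hugging Face tokenizer.json files: the legacy single-string form "left right" and the newer two-element array form ["left", "right"]. A minimal standalone illustration of the same dispatch, using nlohmann/json with made-up example data:

    #include <iostream>
    #include <string>
    #include "nlohmann/json.hpp"

    int main() {
      // Hypothetical merges list mixing both encodings seen in tokenizer.json.
      auto merges = nlohmann::json::parse(R"(["t h", ["th", "e"]])");
      for (const auto& item : merges) {
        std::string w1, w2;
        if (item.is_string()) {
          // Legacy form: one string, the pair separated by a single space.
          std::string line = item.get<std::string>();
          auto pos = line.find(' ');
          w1 = line.substr(0, pos);
          w2 = line.substr(pos + 1);
        } else if (item.is_array()) {
          // Newer form: a two-element array.
          w1 = item.at(0).get<std::string>();
          w2 = item.at(1).get<std::string>();
        }
        std::cout << w1 << " + " << w2 << " -> " << w1 + w2 << "\n";
      }
    }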
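Note (editor's sketch): the ugm_kernels.hpp hunks include "narrow.h" so the int64_t ids can be passed to Id2Token through a checked narrowing cast rather than an implicit truncation. The exact semantics of ort_extensions::narrow are assumed here to follow the usual gsl::narrow pattern: cast, then fail loudly if the value did not survive the round trip.

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    // Assumed gsl::narrow-style checked cast; illustration only, not the repo's narrow.h.
    template <typename To, typename From>
    To narrow(From value) {
      To result = static_cast<To>(value);
      // Reject conversions that change the value (overflow) or flip the sign.
      if (static_cast<From>(result) != value || ((result < To{}) != (value < From{}))) {
        std::fprintf(stderr, "narrowing failed\n");
        std::abort();
      }
      return result;
    }

    int main() {
      int64_t id = 42;
      uint32_t token_id = narrow<uint32_t>(id);  // value preserved, succeeds
      std::printf("%u\n", token_id);
      // narrow<uint32_t>(int64_t{-1}) would abort: -1 is not representable.
    }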