Add a decoder for Unigram tokenizer and unify some classes among tokenizers (#816)
* rename and formalize the file names
* add the decoder impl
* fix a typo
Parent: 6b94f4d7a5
Commit: f204a4c791
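
At a glance, the unification: tokjson_types.h becomes tokenizer_common.h, bpe_jsoncfg.hpp becomes tokenizer_jsconfig.hpp, bpe_tokenizer.hpp becomes bpe_tokenizer_model.hpp, and unescape.h becomes unescape.hpp; TokenJsonConfig and AddedToken move out of the nested ort_extensions::bpe namespace into ort_extensions; the BPE-only BPEDecoderState generalizes into TokenizerDecodingState; and a new SpmUgmDecoder gives Unigram (SPM UGM) models a decode path. A minimal sketch of the decode loop this enables follows; the surrounding setup is assumed, but the Id2Token/TokenizerDecodingState surface matches the headers changed below:

#include <memory>
#include <string>
#include <vector>

// Sketch only: decode a token-id stream through the unified detokenizer surface.
// `tokenizer` is assumed to be a loaded ort_extensions::TokenizerImpl.
std::string DecodeStream(const ort_extensions::TokenizerImpl& tokenizer,
                         const std::vector<extTokenId_t>& ids) {
  // One state object spans the whole stream; it carries incomplete UTF-8
  // bytes and special-token context between calls, for BPE and UGM alike.
  std::unique_ptr<ort_extensions::TokenizerDecodingState> state;
  std::string text;
  for (extTokenId_t id : ids) {
    std::string token;
    if (!tokenizer.Id2Token(id, token, state).IsOk()) {
      break;
    }
    text += token;
  }
  return text;
}
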
@@ -15,7 +15,7 @@
 #include "ustring.h"
 #include "narrow.h"
-#include "tokjson_types.h"
+#include "tokenizer_common.h"
 
 struct KernelBpeDecoder {
  public:
@@ -1,16 +1,14 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "file_sys.h"
-
-#include "bpe_kernels.h"
-#include "bpe_jsoncfg.hpp"
-#include "bpe_tokenizer.hpp"
-#include <limits>
-#include <optional>
 
 #include "base64.h"
 
+#include <optional>
+#include <limits>
+#include "file_sys.h"
+#include "bpe_kernels.h"
+#include "tokenizer_jsconfig.hpp"
+#include "bpe_tokenizer_model.hpp"
 
 using namespace ort_extensions;

@@ -673,11 +671,11 @@ struct VectorEqual {
   }
 };
 
-OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
+OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
   auto added_tokens = tok_json.find("added_tokens");
   if (added_tokens != tok_json.end()) {
     for (const auto& token : *added_tokens) {
-      bpe::AddedToken added_token;
+      AddedToken added_token;
       added_token.id_ = token.value("id", 0);
       added_token.token_type_ = token.value("__type", "");
       added_token.content_ = token.value("content", "");

@@ -721,7 +719,7 @@ bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
   return false;
 }
 
-void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
+void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
   if (!config.add_bos_token_ && !config.bos_token_.empty()) {
     auto post_processor = tok_json.find("post_processor");
     if (post_processor != tok_json.end()) {

@@ -736,7 +734,7 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
   }
 }
 
-OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
+OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
   std::string voc_file = config.GetVocabDataFile();
   std::ifstream ifs = path(voc_file).open();
   if (!ifs.is_open()) {

@@ -785,7 +783,7 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
   return status;
 }
 
-OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config) {
+OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
   std::string voc_file = config.GetVocabDataFile();
   std::ifstream ifs = path(voc_file).open();
   if (!ifs.is_open()) {
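
The added_tokens entries parsed above follow the Hugging Face tokenizer.json layout, where each record carries at least id, __type, and content. A small hedged sketch of that mapping in isolation; the sample entry is illustrative, not taken from a real model:

#include <string>
#include "nlohmann/json.hpp"
#include "tokenizer_common.h"

// Sketch: parse one added_tokens entry the way LoadAddedTokens does,
// filling the now-unnamespaced ort_extensions::AddedToken.
ort_extensions::AddedToken ParseOne() {
  nlohmann::json token = nlohmann::json::parse(R"({
    "id": 32000, "__type": "AddedToken", "content": "<|endoftext|>"
  })");
  ort_extensions::AddedToken added_token;
  added_token.id_ = token.value("id", 0);                // defaults if the key is absent
  added_token.token_type_ = token.value("__type", "");
  added_token.content_ = token.value("content", "");
  return added_token;
}
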
@@ -8,12 +8,7 @@
 #include <vector>
 #include <functional>
 
-#include "ortx_tokenizer.h"
-#include "ext_status.h"
-#include "op_def_struct.h"
-#include "nlohmann/json_fwd.hpp"
-#include "tokjson_types.h"
-#include "ustring.h"
+#include "tokenizer_common.h"
 
 
 struct BpeModelConf {

@@ -116,8 +111,8 @@ struct SpmTokenizer : KernelBpeTokenizer {
 class JsonFastTokenizer : public KernelBpeTokenizer {
  public:
   JsonFastTokenizer();
-  OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
-  OrtxStatus LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config);
+  OrtxStatus Load(const ort_extensions::TokenJsonConfig& config);
+  OrtxStatus LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config);
   OrtxStatus Compute(const ortc::Tensor<std::string>& input,
                      ortc::Tensor<int64_t>& tokenize_output,
                      std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,

@@ -133,9 +128,9 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
   std::string TokenBytesToString(std::vector<uint8_t>& bytes);
   // template functions to avoid including the huge json header file
   bool CheckForSpmModel(const json& tok_json);
-  void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
-  OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
+  void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
+  OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
 
   BpeModelConf json_conf_;
-  std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
+  std::vector<ort_extensions::AddedToken> added_tokens_;
 };
@@ -5,26 +5,19 @@
 
 #include "bpe_kernels.h"
 #include "bpe_decoder.hpp"
-#include "bpe_jsoncfg.hpp"
-#include "bpe_tokenizer.hpp"
-
-namespace ort_extensions {
-struct BPEDecoderState {
-  bool f_special_last{};
-  std::string incomplete_utf8_;
-};
-}  // namespace ort_extensions
+#include "tokenizer_jsconfig.hpp"
+#include "bpe_tokenizer_model.hpp"
 
 class BpeStreamingDecoder : public KernelBpeDecoder {
  public:
   BpeStreamingDecoder() = default;
   ~BpeStreamingDecoder() override = default;
 
-  using BPEDecoderState = ort_extensions::BPEDecoderState;
+  using BPEDecoderState = ort_extensions::TokenizerDecodingState;
 
   // shared the data between the encoder and decoder
   OrtxStatus Load(
-      std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
+      std::shared_ptr<ort_extensions::TokenJsonConfig const> ptr_config,
       const JsonFastTokenizer& encoder) {
     const auto& tok_config = *ptr_config;
     bos_token_ = tok_config.bos_token_;

@@ -258,5 +251,5 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
   extTokenId_t eos_token_id_{0};
   bool add_dummy_prefix_ = false;
   bool spm_model_{};
-  std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
+  std::shared_ptr<ort_extensions::TokenJsonConfig const> tok_config_;
 };
@@ -18,7 +18,7 @@
 #include "nlohmann/json.hpp"
 #include "bpe_utils.hpp"
 #include "trietree.hpp"
-#include "tokjson_types.h"
+#include "tokenizer_common.h"
 
 namespace ort_extensions {
 

@@ -249,7 +249,7 @@ class BpeModel {
     return {};
   }
 
-  OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
+  OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
    for (const auto& token : added_tokens) {
      added_tokens_.Add(ustring(token.content_), 0, token.id_);
    }
@@ -4,7 +4,7 @@
 #include "sentencepiece_processor.h"
 #include "sentencepiece_model.pb.h"
 #include "sentencepiece.pb.h"
-#include "sentencepiece_tokenizer.hpp"
+#include "sentencepiece_tokenizer.h"
 #include "string_tensor.h"
 #include "base64.h"
 #include "narrow.h"
@@ -5,11 +5,17 @@
 
 #include <string>
 
+#include "ortx_tokenizer.h"
+#include "ext_status.h"
+#include "op_def_struct.h"
+#include "nlohmann/json_fwd.hpp"
+
 #include "ustring.h"
 
 namespace ort_extensions {
 class BpeModel;
 
-namespace bpe {
-
 struct AddedToken final {
   uint32_t id_{};
   std::string token_type_;

@@ -23,7 +29,10 @@ struct AddedToken final {
 
 class TokenJsonConfig;  // forward declaration
 
-}  // namespace bpe
+struct TokenizerDecodingState {
+  bool f_special_last{};
+  std::string incomplete_utf8_;
+};
 
 constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
 }  // namespace ort_extensions
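
The TokenizerDecodingState introduced above is what makes byte-level streaming decode safe: one UTF-8 character can span two byte-level tokens, so the decoder parks partial bytes in incomplete_utf8_ between calls. A hedged illustration (the two token ids are hypothetical; the call shape is the one used by the updated tests near the end of this diff):

#include <memory>
#include <string>

// Sketch: decode two byte-level tokens that together form one multi-byte
// UTF-8 character. `tokenizer` is assumed to be a loaded TokenizerImpl.
void DecodeSplitChar(const ort_extensions::TokenizerImpl& tokenizer,
                     extTokenId_t first_half, extTokenId_t second_half,
                     std::string& text) {
  std::unique_ptr<ort_extensions::TokenizerDecodingState> cache;
  std::string token;
  tokenizer.Id2Token(first_half, token, cache);   // partial bytes parked in incomplete_utf8_
  text += token;                                  // typically empty at this point
  tokenizer.Id2Token(second_half, token, cache);  // completed character flushed here
  text += token;
}
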
@@ -3,14 +3,14 @@
 
 #pragma once
 
 #include "ocos.h"
 #include "file_sys.h"
 #include "nlohmann/json.hpp"
 
-#include "tokjson_types.h"
+#include "tokenizer_common.h"
 
-namespace ort_extensions::bpe {
+namespace ort_extensions {
 
 // TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
 class TokenJsonConfig final {
  public:
  static constexpr const char* kDefaultVocabFile = "tokenizer.json";

@@ -26,6 +26,7 @@ class TokenJsonConfig final {
       return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
     }
 
 
-    ortx::path tok_dir(json_path);
+    ortx::path vocab_path(json_path);
+    ortx::path tok_path_obj(json_path);

@@ -122,4 +123,4 @@ class TokenJsonConfig final {
   std::string module_path_;
 };
 
-}  // namespace ort_extensions::bpe
+}  // namespace ort_extensions
@@ -4,18 +4,18 @@
 #include "ocos.h"
 
 #ifdef ENABLE_GPT2_TOKENIZER
-#include "bpe_tokenizer.hpp"
 #include "bpe_kernels.h"
+#include "bpe_tokenizer_model.hpp"
 #include "bpe_decoder.hpp"
 #endif
 
 #ifdef ENABLE_SPM_TOKENIZER
-#include "sentencepiece_tokenizer.hpp"
+#include "sentencepiece_tokenizer.h"
 #include "sentencepiece_decoder.hpp"
 #endif
 
 #ifdef ENABLE_WORDPIECE_TOKENIZER
-#include "wordpiece_tokenizer.hpp"
+#include "wordpiece_tokenizer.h"
 #endif
 
 #ifdef ENABLE_BLINGFIRE
@@ -13,7 +13,7 @@
 #include <charconv>
 #include <optional>
 
-#include "unescape.h"
+#include "unescape.hpp"
 #include "trietree.hpp"
 
 // This Trie Tree is C++ implementation of

@@ -40,6 +40,7 @@ class TrieTokenizer {
  private:
   std::map<int, std::string> idx2token;
   RWKVTrieTree root;
+  using UnescapeUtils = ort_extensions::UnescapeUtils;
 
  public:
   TrieTokenizer(const std::string& text_tokens) {

@@ -62,7 +63,7 @@ class TrieTokenizer {
     std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
     std::string x;
     int key_length = 0;
-    if (ort_extensions::UnquoteString(raw, x)) {
+    if (UnescapeUtils::UnquoteString(raw, x)) {
       std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
     }
     if (x.length() != key_length) {
@@ -20,10 +20,12 @@
 #include "ustring.h"
 #include "nlohmann/json.hpp"
 #include "trietree.hpp"
-#include "bpe_jsoncfg.hpp"
+#include "tokenizer_jsconfig.hpp"
 
 namespace ort_extensions {
 
+class SpmUgmDecoder;  // forward declaration
+
 struct SpmUgmTokenizer {
   using json = nlohmann::json;
   using VocabTrieTree = ort_extensions::TrieTree<char, extTokenId_t, -1>;

@@ -109,7 +111,7 @@ struct SpmUgmTokenizer {
     return {};
   }
 
-  OrtxStatus Load(const bpe::TokenJsonConfig& config) {
+  OrtxStatus Load(const TokenJsonConfig& config) {
     ortx::path vocab_path(config.GetVocabDataFile());
     if (!vocab_path.exists()) {
       return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");

@@ -417,6 +419,7 @@ struct SpmUgmTokenizer {
     }
   }
 
+  friend class SpmUgmDecoder;
   // escaped space symbol - U+2581 (Lower One Eighth Block)
   static constexpr double unknown_token_score_penalty_ = 10.0;

@@ -454,4 +457,84 @@ struct SpmUgmTokenizer {
   std::string unk_token_ = "<unk>";
 };
 
+
+class SpmUgmDecoder {
+ public:
+  SpmUgmDecoder() {
+  }
+
+  OrtxStatus Load(const TokenJsonConfig& config, const SpmUgmTokenizer& tokenizer) {
+    auto vocab_size = tokenizer.vocab_.size();
+    vocab_.resize(vocab_size);
+    for (auto iter = tokenizer.vocab_.begin(); iter != tokenizer.vocab_.end(); ++iter) {
+      vocab_[std::get<0>(iter->second)] = iter->first;
+    }
+
+    unknown_token_ = tokenizer.unk_token_;
+    special_token_ids_ = tokenizer.special_token_ids_;
+    tokenizer_add_space_prefix_ = tokenizer.tokenizer_add_space_prefix_;
+    return {};
+  }
+
+  OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
+    const int64_t* p_ids = ids.Data();
+    const auto& ids_dim = ids.Shape();
+    std::vector<int64_t> output_dim = {1};
+    if (ids_dim.size() > 1) {
+      output_dim.resize(ids_dim.size() - 1);
+      std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
+    }
+
+    size_t seq_len = ids_dim.back();
+    size_t string_batch = ids.NumberOfElement() / seq_len;
+
+    std::vector<std::string> decoded_strings;
+    decoded_strings.reserve(string_batch);
+    const std::string ws = " ";
+    for (auto n = string_batch; n > 0; n--) {
+      std::string text;
+      for (int64_t i = 0; i < seq_len; ++i) {
+        std::string token;
+        Id2Token(p_ids[i], token, nullptr);
+        if (token.find(spm_escaped_space) == 0) {
+          token = ws + token.substr(spm_escaped_space.length());
+        }
+
+        text += token;
+      }
+
+      if (tokenizer_add_space_prefix_) {
+        if (text.length() > 0 && text[0] == ' ') {
+          text = text.substr(1);
+        }
+      }
+      decoded_strings.push_back(text);
+    }
+
+    output.SetStringOutput(decoded_strings, output_dim);
+    return {};
+  }
+
+  OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** /* state */) const {
+    if (special_token_ids_.count(id)) {
+      token = "";
+      return {};
+    }
+
+    if (id >= vocab_.size()) {
+      token = unknown_token_;
+      return {};
+    }
+
+    token = vocab_[id];
+    return {};
+  }
+
+ private:
+  bool tokenizer_add_space_prefix_ = true;
+  std::vector<std::string> vocab_;
+  std::string unknown_token_ = "<unk>";
+  std::set<extTokenId_t> special_token_ids_;
+};
+
 }  // namespace ort_extensions
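
The new SpmUgmDecoder above rebuilds an id-to-string table by inverting the tokenizer's vocab map, blanks out special tokens, and turns the SentencePiece escaped-space marker (U+2581, kept in spm_escaped_space) back into a plain space. A standalone sketch of that last substitution, the only non-obvious string handling in Compute:

#include <string>
#include <string_view>

// spm_escaped_space is U+2581 "▁" (UTF-8 bytes E2 96 81), as defined in
// tokenizer_common.h; duplicated here only to keep the sketch self-contained.
constexpr std::string_view kSpmEscapedSpace = "\xE2\x96\x81";

// Mirror of the per-token step in SpmUgmDecoder::Compute: a leading "▁"
// marks a word boundary and becomes a space when detokenizing.
std::string RestoreSpace(const std::string& token) {
  if (token.find(kSpmEscapedSpace) == 0) {
    return " " + token.substr(kSpmEscapedSpace.length());
  }
  return token;
}

// e.g. RestoreSpace("\xE2\x96\x81hello") yields " hello"
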
@@ -1,161 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-#pragma once
-
-#include "ustring.h"
-#include <vector>
-
-namespace ort_extensions {
-
-inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
-inline bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
-
-inline unsigned int hex_digit_to_int(char c) {
-  unsigned int x = static_cast<unsigned char>(c);
-  if (x > '9') {
-    x += 9;
-  }
-  return x & 0xf;
-}
-
-inline bool IsSurrogate(char32_t c) {
-  return c >= 0xD800 && c <= 0xDFFF;
-}
-
-// Unescape a Python escaped string
-inline bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
-
-  // reserve enough space for the worst case, and final size will be calculated at the end.
-  unescaped.resize(source.length());
-  char* d = unescaped.data();
-  const char* p = source.data();
-  const char* end = p + source.size();
-  const char* last_byte = end - 1;
-
-  while (p == d && p < end && *p != '\\') p++, d++;
-
-  while (p < end) {
-    if (*p != '\\') {
-      *d++ = *p++;
-    } else {
-      if (++p > last_byte) {
-        return false;
-      }
-      switch (*p) {
-        case 'n':
-          *d++ = '\n';
-          break;
-        case 'r':
-          *d++ = '\r';
-          break;
-        case 't':
-          *d++ = '\t';
-          break;
-          break;
-        case '\\':
-          *d++ = '\\';
-          break;
-        case '\'':
-          *d++ = '\'';
-          break;
-        case '"':
-          *d++ = '\"';
-          break;
-        case 'x':
-        case 'X': {
-          if (p >= last_byte) {
-            return false;
-          } else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
-            return false;
-          }
-          unsigned int ch = 0;
-          const char* hex_start = p;
-          while (p < last_byte &&
-                 IsHexDigit(static_cast<unsigned char>(p[1])))
-            ch = (ch << 4) + hex_digit_to_int(*++p);
-          if (ch > 0xFF && !is_bytes) {
-            return false;
-          }
-          if (is_bytes) {
-            *d++ = static_cast<char>(ch);
-          } else {
-            d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
-          }
-          break;
-        }
-        case 'u': {
-          char32_t rune = 0;
-          const char* hex_start = p;
-          if (p + 4 >= end) {
-            return false;
-          }
-          for (int i = 0; i < 4; ++i) {
-            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
-              rune = (rune << 4) + hex_digit_to_int(*++p);
-            } else {
-              return false;
-            }
-          }
-          if (IsSurrogate(rune)) {
-            return false;
-          }
-          d += ustring::EncodeUTF8Char(d, rune);
-          break;
-        }
-        case 'U': {
-          char32_t rune = 0;
-          const char* hex_start = p;
-          if (p + 8 >= end) {
-            return false;
-          }
-          for (int i = 0; i < 8; ++i) {
-            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
-              uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
-              if (newrune > 0x10FFFF) {
-                return false;
-              } else {
-                rune = newrune;
-              }
-            } else {
-              return false;
-            }
-          }
-          if (IsSurrogate(rune)) {
-            return false;
-          }
-          d += ustring::EncodeUTF8Char(d, rune);
-          break;
-        }
-        default: {
-          return false;
-        }
-      }
-      p++;
-    }
-  }
-
-  unescaped.resize(d - unescaped.data());
-  return true;
-}
-
-inline bool UnquoteString(const std::string& str, std::string& unquoted) {
-  bool is_bytes = false;
-  int idx_0 = 0;
-  if (str.front() == 'b') {
-    is_bytes = true;
-    idx_0 = 1;
-  }
-  std::string str_view(str.data() + idx_0, str.length() - idx_0);
-  if (str_view.length() < 2) {
-    return false;
-  }
-
-  if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
-    return false;
-  }
-
-  // unescape the string
-  return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
-}
-
-}  // namespace ort_extensions
@@ -0,0 +1,168 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include <vector>
+#include <string>
+#include <string_view>
+
+#include "ustring.h"
+
+
+namespace ort_extensions {
+
+class UnescapeUtils {
+ public:
+  static bool IsDigit(char c) { return c >= '0' && c <= '9'; }
+  static bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
+
+  static unsigned int hex_digit_to_int(char c) {
+    unsigned int x = static_cast<unsigned char>(c);
+    if (x > '9') {
+      x += 9;
+    }
+    return x & 0xf;
+  }
+
+  static bool IsSurrogate(char32_t c) {
+    return c >= 0xD800 && c <= 0xDFFF;
+  }
+
+  // Unescape a Python escaped string
+  static bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
+
+    // reserve enough space for the worst case, and final size will be calculated at the end.
+    unescaped.resize(source.length());
+    char* d = unescaped.data();
+    const char* p = source.data();
+    const char* end = p + source.size();
+    const char* last_byte = end - 1;
+
+    while (p == d && p < end && *p != '\\') p++, d++;
+
+    while (p < end) {
+      if (*p != '\\') {
+        *d++ = *p++;
+      } else {
+        if (++p > last_byte) {
+          return false;
+        }
+        switch (*p) {
+          case 'n':
+            *d++ = '\n';
+            break;
+          case 'r':
+            *d++ = '\r';
+            break;
+          case 't':
+            *d++ = '\t';
+            break;
+            break;
+          case '\\':
+            *d++ = '\\';
+            break;
+          case '\'':
+            *d++ = '\'';
+            break;
+          case '"':
+            *d++ = '\"';
+            break;
+          case 'x':
+          case 'X': {
+            if (p >= last_byte) {
+              return false;
+            } else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
+              return false;
+            }
+            unsigned int ch = 0;
+            const char* hex_start = p;
+            while (p < last_byte &&
+                   IsHexDigit(static_cast<unsigned char>(p[1])))
+              ch = (ch << 4) + hex_digit_to_int(*++p);
+            if (ch > 0xFF && !is_bytes) {
+              return false;
+            }
+            if (is_bytes) {
+              *d++ = static_cast<char>(ch);
+            } else {
+              d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
+            }
+            break;
+          }
+          case 'u': {
+            char32_t rune = 0;
+            const char* hex_start = p;
+            if (p + 4 >= end) {
+              return false;
+            }
+            for (int i = 0; i < 4; ++i) {
+              if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
+                rune = (rune << 4) + hex_digit_to_int(*++p);
+              } else {
+                return false;
+              }
+            }
+            if (IsSurrogate(rune)) {
+              return false;
+            }
+            d += ustring::EncodeUTF8Char(d, rune);
+            break;
+          }
+          case 'U': {
+            char32_t rune = 0;
+            const char* hex_start = p;
+            if (p + 8 >= end) {
+              return false;
+            }
+            for (int i = 0; i < 8; ++i) {
+              if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
+                uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
+                if (newrune > 0x10FFFF) {
+                  return false;
+                } else {
+                  rune = newrune;
+                }
+              } else {
+                return false;
+              }
+            }
+            if (IsSurrogate(rune)) {
+              return false;
+            }
+            d += ustring::EncodeUTF8Char(d, rune);
+            break;
+          }
+          default: {
+            return false;
+          }
+        }
+        p++;
+      }
+    }
+
+    unescaped.resize(d - unescaped.data());
+    return true;
+  }
+
+  static bool UnquoteString(const std::string& str, std::string& unquoted) {
+    bool is_bytes = false;
+    int idx_0 = 0;
+    if (str.front() == 'b') {
+      is_bytes = true;
+      idx_0 = 1;
+    }
+    std::string str_view(str.data() + idx_0, str.length() - idx_0);
+    if (str_view.length() < 2) {
+      return false;
+    }
+
+    if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
+      return false;
+    }
+
+    // unescape the string
+    return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
+  }
+};
+
+}  // namespace ort_extensions
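
The deleted free functions reappear above unchanged as static members of UnescapeUtils, so callers such as the RWKV TrieTokenizer can pull them in with a single using-alias instead of namespace-level names. A hedged usage sketch (the quoted literals are illustrative; the vocab-line format is the one implied by the TrieTokenizer hunk earlier):

#include <iostream>
#include <string>
#include "unescape.hpp"

// Sketch: unquote Python-style token literals the way TrieTokenizer does.
// UnquoteString strips the optional b-prefix and quotes, then resolves
// \n, \xNN, \uXXXX, and similar escapes.
int main() {
  std::string decoded;
  if (ort_extensions::UnescapeUtils::UnquoteString("'hel\\nlo'", decoded)) {
    std::cout << decoded.size() << '\n';  // 6: h e l \n l o
  }

  std::string bytes;
  // b'...' marks a byte string: each \xNN stays a raw byte instead of being
  // re-encoded as UTF-8. These three bytes happen to be U+2581 "▁".
  ort_extensions::UnescapeUtils::UnquoteString("b'\\xe2\\x96\\x81'", bytes);
  return 0;
}
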
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "wordpiece_tokenizer.hpp"
+#include "wordpiece_tokenizer.h"
 #include "nlohmann/json.hpp"
 
 KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
@@ -13,7 +13,7 @@ class DetokenizerCache : public OrtxObjectImpl {
   DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
   ~DetokenizerCache() override = default;
 
-  std::unique_ptr<BPEDecoderState> decoder_state_{};
+  std::unique_ptr<TokenizerDecodingState> decoder_state_{};
   std::string last_text_{};  // last detokenized text
 };
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #include "bpe_kernels.h"
-#include "bpe_tokenizer.hpp"
+#include "bpe_tokenizer_model.hpp"
 #include "bpe_decoder.hpp"
 #include "ugm_kernels.hpp"
 

@@ -16,7 +16,7 @@ TokenizerImpl::TokenizerImpl()
 TokenizerImpl::~TokenizerImpl() {};
 
 OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
-  tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
+  tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
   auto status = tok_config_->Load(tok_path);
   if (!status.IsOk()) {
     return status;

@@ -25,8 +25,18 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
   if (tok_config_->tokenizer_class_.empty()) {
     auto tokenizer = std::make_unique<SpmUgmTokenizer>();
     status = tokenizer->Load(*tok_config_);
     if (!status.IsOk()) {
       return status;
     }
+    auto detok = std::make_unique<SpmUgmDecoder>();
+
+    if (status.IsOk()) {
+      status = detok->Load(*tok_config_, *tokenizer);
+    }
+
+    if (status.IsOk()) {
+      tokenizer_ = std::move(tokenizer);
+      detokenizer_ = std::move(detok);
+    }
+
+    return status;

@@ -38,14 +48,16 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
   auto fx_load = vocab_file_path.extension() == ".json"?
       &JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
   status = (tokenizer.get()->*fx_load)(*tok_config_);
 
-  if (status.IsOk()) {
-    detokenizer_ = std::make_unique<BpeStreamingDecoder>();
-    status = detokenizer_->Load(tok_config_, *tokenizer);
-    if (!status.IsOk()) {
-      return status;
-    }
-  }
+  auto detok = std::make_unique<BpeStreamingDecoder>();
+  status = detok->Load(tok_config_, *tokenizer);
+
+  if (status.IsOk()) {
+    tokenizer_ = std::move(tokenizer);
+    detokenizer_ = std::move(detok);
+  }
 
   return status;

@@ -81,7 +93,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
     std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
     ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
     ortc::Tensor<std::string> ts_output;
-    OrtxStatus status = detokenizer_->Compute(ts_input, ts_output);
+    OrtxStatus status = std::visit([&](auto& detokenizer) {
+      return detokenizer->Compute(ts_input, ts_output); }, detokenizer_);
+
     if (!status.IsOk()) {
       return status;
     }

@@ -90,8 +104,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
   return {};
 }
 
-OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
-  return detokenizer_->Id2Token(id, token, state);
+OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const {
+  return std::visit([&](auto& detokenizer) {
+    return detokenizer->Id2Token(id, token, state); }, detokenizer_);
 }
 
 static std::map<std::string, std::string> LANGUAGES = {
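
Because detokenizer_ is now a std::variant over two unique_ptr types (see the header change just below), every call site dispatches through std::visit with a generic lambda; both alternatives expose the same Compute/Id2Token shape, so one lambda body covers both without a virtual base class. A reduced standalone model of that pattern (the decoder types here are stand-ins, not the real classes):

#include <iostream>
#include <memory>
#include <string>
#include <variant>

// Two decoder types with the same call shape, standing in for
// BpeStreamingDecoder and SpmUgmDecoder.
struct BpeLikeDecoder { std::string Decode(int id) const { return "bpe:" + std::to_string(id); } };
struct UgmLikeDecoder { std::string Decode(int id) const { return "ugm:" + std::to_string(id); } };

using AnyDecoder = std::variant<std::unique_ptr<BpeLikeDecoder>, std::unique_ptr<UgmLikeDecoder>>;

std::string Decode(const AnyDecoder& detokenizer, int id) {
  // Generic lambda: instantiated once per variant alternative at compile time.
  return std::visit([&](const auto& d) { return d->Decode(id); }, detokenizer);
}

int main() {
  AnyDecoder d = std::make_unique<UgmLikeDecoder>();
  std::cout << Decode(d, 42) << '\n';  // prints "ugm:42"
  return 0;
}
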
@@ -7,7 +7,7 @@
 
 #include "bpe_kernels.h"
 #include "ugm_kernels.hpp"
-#include "bpe_jsoncfg.hpp"
+#include "tokenizer_jsconfig.hpp"
 #include "bpe_streaming.hpp"
 #include "c_api_utils.hpp"
 

@@ -34,8 +34,8 @@ class TokenizerImpl : public OrtxObjectImpl {
     return {};
   }
 
-  OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
-    BPEDecoderState* state_ptr = cache.get();
+  OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<TokenizerDecodingState>& cache) const {
+    TokenizerDecodingState* state_ptr = cache.get();
     OrtxStatus status = Id2Token(id, token, &state_ptr);
     if (status.IsOk()) {
       if (state_ptr != cache.get()) {

@@ -51,7 +51,7 @@ class TokenizerImpl : public OrtxObjectImpl {
 
   OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;
 
-  OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;
+  OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const;
 
   OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
                                  std::vector<std::vector<extTokenId_t>>& t_ids) const;

@@ -61,8 +61,11 @@ class TokenizerImpl : public OrtxObjectImpl {
   using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
   std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;
 
-  std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
-  std::unique_ptr<BpeStreamingDecoder> detokenizer_;
+  using bpe_decoder_t = std::unique_ptr<BpeStreamingDecoder>;
+  using ugm_decoder_t = std::unique_ptr<SpmUgmDecoder>;
+  std::variant<bpe_decoder_t, ugm_decoder_t> detokenizer_;
+
+  std::shared_ptr<ort_extensions::TokenJsonConfig> tok_config_;
 };
 
 }  // namespace ort_extensions
@@ -335,7 +335,7 @@ TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
   EXPECT_EQ(token_ids.size(), 1);
 
   std::string text;
-  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
+  std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
   // token_ids[0].insert(token_ids[0].begin() + 2, 607);  // <0x20>
   token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
   for (const auto& token_id : token_ids[0]) {

@@ -369,7 +369,7 @@ TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
   DumpTokenIds(token_ids);
 
   std::string text;
-  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
+  std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
   // std::cout << "\"";
   for (const auto& token_id : token_ids[0]) {
     std::string token;

@@ -406,7 +406,7 @@ TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
   DumpTokenIds(token_ids);
 
   std::string text;
-  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
+  std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
   // std::cout << "\"";
   for (const auto& token_id : token_ids[0]) {
     std::string token;

@@ -464,4 +464,20 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
   // AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
   EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
     0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
+
+  OrtxObjectPtr<OrtxStringArray> decoded_text;
+  OrtxDetokenize(tokenizer.get(), token_ids.get(), ort_extensions::ptr(decoded_text));
+  EXPECT_EQ(decoded_text.Code(), kOrtxOK);
+
+  const char* text = nullptr;
+  OrtxStringArrayGetItem(decoded_text.get(), 0, &text);
+  // because the tokenization remove the character from the string, the decoded text is not the same as the input text.
+  std::string filtered_text(input[0]);
+  filtered_text.erase(std::remove_if(
+    filtered_text.begin(), filtered_text.end(), [](unsigned char chr){ return chr < 0x20; }), filtered_text.end());
+  // remove the consecutive spaces
+  filtered_text.erase(std::unique(filtered_text.begin(), filtered_text.end(),
+    [](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }), filtered_text.end());
+
+  EXPECT_STREQ(filtered_text.c_str(), text);
 }
@@ -3,7 +3,7 @@
 
 #include "gtest/gtest.h"
 #include "string_utils.h"
-#include "wordpiece_tokenizer.hpp"
+#include "wordpiece_tokenizer.h"
 #include "bert_tokenizer.hpp"
 
 #include <clocale>