Add a decoder for Unigram tokenizer and unify some classes among tokenizers (#816)
* rename and formalize the file names * add the decoder impl * fix a typo
This commit is contained in:
Родитель
6b94f4d7a5
Коммит
f204a4c791
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
#include "ustring.h"
|
#include "ustring.h"
|
||||||
#include "narrow.h"
|
#include "narrow.h"
|
||||||
#include "tokjson_types.h"
|
#include "tokenizer_common.h"
|
||||||
|
|
||||||
struct KernelBpeDecoder {
|
struct KernelBpeDecoder {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "file_sys.h"
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
#include "bpe_kernels.h"
|
|
||||||
#include "bpe_jsoncfg.hpp"
|
|
||||||
#include "bpe_tokenizer.hpp"
|
|
||||||
|
|
||||||
#include "base64.h"
|
#include "base64.h"
|
||||||
|
#include "file_sys.h"
|
||||||
#include <optional>
|
#include "bpe_kernels.h"
|
||||||
#include <limits>
|
#include "tokenizer_jsconfig.hpp"
|
||||||
|
#include "bpe_tokenizer_model.hpp"
|
||||||
|
|
||||||
using namespace ort_extensions;
|
using namespace ort_extensions;
|
||||||
|
|
||||||
|
@ -673,11 +671,11 @@ struct VectorEqual {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
|
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
|
||||||
auto added_tokens = tok_json.find("added_tokens");
|
auto added_tokens = tok_json.find("added_tokens");
|
||||||
if (added_tokens != tok_json.end()) {
|
if (added_tokens != tok_json.end()) {
|
||||||
for (const auto& token : *added_tokens) {
|
for (const auto& token : *added_tokens) {
|
||||||
bpe::AddedToken added_token;
|
AddedToken added_token;
|
||||||
added_token.id_ = token.value("id", 0);
|
added_token.id_ = token.value("id", 0);
|
||||||
added_token.token_type_ = token.value("__type", "");
|
added_token.token_type_ = token.value("__type", "");
|
||||||
added_token.content_ = token.value("content", "");
|
added_token.content_ = token.value("content", "");
|
||||||
|
@ -721,7 +719,7 @@ bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
|
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
|
||||||
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
|
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
|
||||||
auto post_processor = tok_json.find("post_processor");
|
auto post_processor = tok_json.find("post_processor");
|
||||||
if (post_processor != tok_json.end()) {
|
if (post_processor != tok_json.end()) {
|
||||||
|
@ -736,7 +734,7 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
|
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
|
||||||
std::string voc_file = config.GetVocabDataFile();
|
std::string voc_file = config.GetVocabDataFile();
|
||||||
std::ifstream ifs = path(voc_file).open();
|
std::ifstream ifs = path(voc_file).open();
|
||||||
if (!ifs.is_open()) {
|
if (!ifs.is_open()) {
|
||||||
|
@ -785,7 +783,7 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config) {
|
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
|
||||||
std::string voc_file = config.GetVocabDataFile();
|
std::string voc_file = config.GetVocabDataFile();
|
||||||
std::ifstream ifs = path(voc_file).open();
|
std::ifstream ifs = path(voc_file).open();
|
||||||
if (!ifs.is_open()) {
|
if (!ifs.is_open()) {
|
||||||
|
|
|
@ -8,12 +8,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "ortx_tokenizer.h"
|
#include "tokenizer_common.h"
|
||||||
#include "ext_status.h"
|
|
||||||
#include "op_def_struct.h"
|
|
||||||
#include "nlohmann/json_fwd.hpp"
|
|
||||||
#include "tokjson_types.h"
|
|
||||||
#include "ustring.h"
|
|
||||||
|
|
||||||
|
|
||||||
struct BpeModelConf {
|
struct BpeModelConf {
|
||||||
|
@ -116,8 +111,8 @@ struct SpmTokenizer : KernelBpeTokenizer {
|
||||||
class JsonFastTokenizer : public KernelBpeTokenizer {
|
class JsonFastTokenizer : public KernelBpeTokenizer {
|
||||||
public:
|
public:
|
||||||
JsonFastTokenizer();
|
JsonFastTokenizer();
|
||||||
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
|
OrtxStatus Load(const ort_extensions::TokenJsonConfig& config);
|
||||||
OrtxStatus LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config);
|
OrtxStatus LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config);
|
||||||
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
|
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
|
||||||
ortc::Tensor<int64_t>& tokenize_output,
|
ortc::Tensor<int64_t>& tokenize_output,
|
||||||
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
|
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
|
||||||
|
@ -133,9 +128,9 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
|
||||||
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
|
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
|
||||||
// template functions to avoid including the huge json header file
|
// template functions to avoid including the huge json header file
|
||||||
bool CheckForSpmModel(const json& tok_json);
|
bool CheckForSpmModel(const json& tok_json);
|
||||||
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
|
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
|
||||||
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
|
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
|
||||||
|
|
||||||
BpeModelConf json_conf_;
|
BpeModelConf json_conf_;
|
||||||
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
|
std::vector<ort_extensions::AddedToken> added_tokens_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -5,26 +5,19 @@
|
||||||
|
|
||||||
#include "bpe_kernels.h"
|
#include "bpe_kernels.h"
|
||||||
#include "bpe_decoder.hpp"
|
#include "bpe_decoder.hpp"
|
||||||
#include "bpe_jsoncfg.hpp"
|
#include "tokenizer_jsconfig.hpp"
|
||||||
#include "bpe_tokenizer.hpp"
|
#include "bpe_tokenizer_model.hpp"
|
||||||
|
|
||||||
namespace ort_extensions {
|
|
||||||
struct BPEDecoderState {
|
|
||||||
bool f_special_last{};
|
|
||||||
std::string incomplete_utf8_;
|
|
||||||
};
|
|
||||||
} // namespace ort_extensions
|
|
||||||
|
|
||||||
class BpeStreamingDecoder : public KernelBpeDecoder {
|
class BpeStreamingDecoder : public KernelBpeDecoder {
|
||||||
public:
|
public:
|
||||||
BpeStreamingDecoder() = default;
|
BpeStreamingDecoder() = default;
|
||||||
~BpeStreamingDecoder() override = default;
|
~BpeStreamingDecoder() override = default;
|
||||||
|
|
||||||
using BPEDecoderState = ort_extensions::BPEDecoderState;
|
using BPEDecoderState = ort_extensions::TokenizerDecodingState;
|
||||||
|
|
||||||
// shared the data between the encoder and decoder
|
// shared the data between the encoder and decoder
|
||||||
OrtxStatus Load(
|
OrtxStatus Load(
|
||||||
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
|
std::shared_ptr<ort_extensions::TokenJsonConfig const> ptr_config,
|
||||||
const JsonFastTokenizer& encoder) {
|
const JsonFastTokenizer& encoder) {
|
||||||
const auto& tok_config = *ptr_config;
|
const auto& tok_config = *ptr_config;
|
||||||
bos_token_ = tok_config.bos_token_;
|
bos_token_ = tok_config.bos_token_;
|
||||||
|
@ -258,5 +251,5 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
|
||||||
extTokenId_t eos_token_id_{0};
|
extTokenId_t eos_token_id_{0};
|
||||||
bool add_dummy_prefix_ = false;
|
bool add_dummy_prefix_ = false;
|
||||||
bool spm_model_{};
|
bool spm_model_{};
|
||||||
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
|
std::shared_ptr<ort_extensions::TokenJsonConfig const> tok_config_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
#include "bpe_utils.hpp"
|
#include "bpe_utils.hpp"
|
||||||
#include "trietree.hpp"
|
#include "trietree.hpp"
|
||||||
#include "tokjson_types.h"
|
#include "tokenizer_common.h"
|
||||||
|
|
||||||
namespace ort_extensions {
|
namespace ort_extensions {
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ class BpeModel {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
|
OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
|
||||||
for (const auto& token : added_tokens) {
|
for (const auto& token : added_tokens) {
|
||||||
added_tokens_.Add(ustring(token.content_), 0, token.id_);
|
added_tokens_.Add(ustring(token.content_), 0, token.id_);
|
||||||
}
|
}
|
|
@ -4,7 +4,7 @@
|
||||||
#include "sentencepiece_processor.h"
|
#include "sentencepiece_processor.h"
|
||||||
#include "sentencepiece_model.pb.h"
|
#include "sentencepiece_model.pb.h"
|
||||||
#include "sentencepiece.pb.h"
|
#include "sentencepiece.pb.h"
|
||||||
#include "sentencepiece_tokenizer.hpp"
|
#include "sentencepiece_tokenizer.h"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
#include "base64.h"
|
#include "base64.h"
|
||||||
#include "narrow.h"
|
#include "narrow.h"
|
||||||
|
|
|
@ -5,11 +5,17 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "ortx_tokenizer.h"
|
||||||
|
#include "ext_status.h"
|
||||||
|
#include "op_def_struct.h"
|
||||||
|
#include "nlohmann/json_fwd.hpp"
|
||||||
|
|
||||||
|
#include "ustring.h"
|
||||||
|
|
||||||
|
|
||||||
namespace ort_extensions {
|
namespace ort_extensions {
|
||||||
class BpeModel;
|
class BpeModel;
|
||||||
|
|
||||||
namespace bpe {
|
|
||||||
|
|
||||||
struct AddedToken final {
|
struct AddedToken final {
|
||||||
uint32_t id_{};
|
uint32_t id_{};
|
||||||
std::string token_type_;
|
std::string token_type_;
|
||||||
|
@ -23,7 +29,10 @@ struct AddedToken final {
|
||||||
|
|
||||||
class TokenJsonConfig; // forward declaration
|
class TokenJsonConfig; // forward declaration
|
||||||
|
|
||||||
} // namespace bpe
|
struct TokenizerDecodingState {
|
||||||
|
bool f_special_last{};
|
||||||
|
std::string incomplete_utf8_;
|
||||||
|
};
|
||||||
|
|
||||||
constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
|
constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
|
||||||
} // namespace ort_extensions
|
} // namespace ort_extensions
|
|
@ -3,14 +3,14 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ocos.h"
|
|
||||||
#include "file_sys.h"
|
#include "file_sys.h"
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
|
|
||||||
#include "tokjson_types.h"
|
#include "tokenizer_common.h"
|
||||||
|
|
||||||
namespace ort_extensions::bpe {
|
namespace ort_extensions {
|
||||||
|
|
||||||
|
// TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
|
||||||
class TokenJsonConfig final {
|
class TokenJsonConfig final {
|
||||||
public:
|
public:
|
||||||
static constexpr const char* kDefaultVocabFile = "tokenizer.json";
|
static constexpr const char* kDefaultVocabFile = "tokenizer.json";
|
||||||
|
@ -26,6 +26,7 @@ class TokenJsonConfig final {
|
||||||
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
|
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ortx::path tok_dir(json_path);
|
ortx::path tok_dir(json_path);
|
||||||
ortx::path vocab_path(json_path);
|
ortx::path vocab_path(json_path);
|
||||||
ortx::path tok_path_obj(json_path);
|
ortx::path tok_path_obj(json_path);
|
||||||
|
@ -122,4 +123,4 @@ class TokenJsonConfig final {
|
||||||
std::string module_path_;
|
std::string module_path_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ort_extensions::bpe
|
} // namespace ort_extensions
|
|
@ -4,18 +4,18 @@
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
#ifdef ENABLE_GPT2_TOKENIZER
|
#ifdef ENABLE_GPT2_TOKENIZER
|
||||||
#include "bpe_tokenizer.hpp"
|
|
||||||
#include "bpe_kernels.h"
|
#include "bpe_kernels.h"
|
||||||
|
#include "bpe_tokenizer_model.hpp"
|
||||||
#include "bpe_decoder.hpp"
|
#include "bpe_decoder.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_SPM_TOKENIZER
|
#ifdef ENABLE_SPM_TOKENIZER
|
||||||
#include "sentencepiece_tokenizer.hpp"
|
#include "sentencepiece_tokenizer.h"
|
||||||
#include "sentencepiece_decoder.hpp"
|
#include "sentencepiece_decoder.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_WORDPIECE_TOKENIZER
|
#ifdef ENABLE_WORDPIECE_TOKENIZER
|
||||||
#include "wordpiece_tokenizer.hpp"
|
#include "wordpiece_tokenizer.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_BLINGFIRE
|
#ifdef ENABLE_BLINGFIRE
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include <charconv>
|
#include <charconv>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
#include "unescape.h"
|
#include "unescape.hpp"
|
||||||
#include "trietree.hpp"
|
#include "trietree.hpp"
|
||||||
|
|
||||||
// This Trie Tree is C++ implementation of
|
// This Trie Tree is C++ implementation of
|
||||||
|
@ -40,6 +40,7 @@ class TrieTokenizer {
|
||||||
private:
|
private:
|
||||||
std::map<int, std::string> idx2token;
|
std::map<int, std::string> idx2token;
|
||||||
RWKVTrieTree root;
|
RWKVTrieTree root;
|
||||||
|
using UnescapeUtils = ort_extensions::UnescapeUtils;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TrieTokenizer(const std::string& text_tokens) {
|
TrieTokenizer(const std::string& text_tokens) {
|
||||||
|
@ -62,7 +63,7 @@ class TrieTokenizer {
|
||||||
std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
|
std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
|
||||||
std::string x;
|
std::string x;
|
||||||
int key_length = 0;
|
int key_length = 0;
|
||||||
if (ort_extensions::UnquoteString(raw, x)) {
|
if (UnescapeUtils::UnquoteString(raw, x)) {
|
||||||
std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
|
std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
|
||||||
}
|
}
|
||||||
if (x.length() != key_length) {
|
if (x.length() != key_length) {
|
||||||
|
|
|
@ -20,10 +20,12 @@
|
||||||
#include "ustring.h"
|
#include "ustring.h"
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
#include "trietree.hpp"
|
#include "trietree.hpp"
|
||||||
#include "bpe_jsoncfg.hpp"
|
#include "tokenizer_jsconfig.hpp"
|
||||||
|
|
||||||
namespace ort_extensions {
|
namespace ort_extensions {
|
||||||
|
|
||||||
|
class SpmUgmDecoder; // forward declaration
|
||||||
|
|
||||||
struct SpmUgmTokenizer {
|
struct SpmUgmTokenizer {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
using VocabTrieTree = ort_extensions::TrieTree<char, extTokenId_t, -1>;
|
using VocabTrieTree = ort_extensions::TrieTree<char, extTokenId_t, -1>;
|
||||||
|
@ -109,7 +111,7 @@ struct SpmUgmTokenizer {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus Load(const bpe::TokenJsonConfig& config) {
|
OrtxStatus Load(const TokenJsonConfig& config) {
|
||||||
ortx::path vocab_path(config.GetVocabDataFile());
|
ortx::path vocab_path(config.GetVocabDataFile());
|
||||||
if (!vocab_path.exists()) {
|
if (!vocab_path.exists()) {
|
||||||
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");
|
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");
|
||||||
|
@ -417,6 +419,7 @@ struct SpmUgmTokenizer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
friend class SpmUgmDecoder;
|
||||||
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
||||||
static constexpr double unknown_token_score_penalty_ = 10.0;
|
static constexpr double unknown_token_score_penalty_ = 10.0;
|
||||||
|
|
||||||
|
@ -454,4 +457,84 @@ struct SpmUgmTokenizer {
|
||||||
std::string unk_token_ = "<unk>";
|
std::string unk_token_ = "<unk>";
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class SpmUgmDecoder {
|
||||||
|
public:
|
||||||
|
SpmUgmDecoder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtxStatus Load(const TokenJsonConfig& config, const SpmUgmTokenizer& tokenizer) {
|
||||||
|
auto vocab_size = tokenizer.vocab_.size();
|
||||||
|
vocab_.resize(vocab_size);
|
||||||
|
for (auto iter = tokenizer.vocab_.begin(); iter != tokenizer.vocab_.end(); ++iter) {
|
||||||
|
vocab_[std::get<0>(iter->second)] = iter->first;
|
||||||
|
}
|
||||||
|
|
||||||
|
unknown_token_ = tokenizer.unk_token_;
|
||||||
|
special_token_ids_ = tokenizer.special_token_ids_;
|
||||||
|
tokenizer_add_space_prefix_ = tokenizer.tokenizer_add_space_prefix_;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
|
||||||
|
const int64_t* p_ids = ids.Data();
|
||||||
|
const auto& ids_dim = ids.Shape();
|
||||||
|
std::vector<int64_t> output_dim = {1};
|
||||||
|
if (ids_dim.size() > 1) {
|
||||||
|
output_dim.resize(ids_dim.size() - 1);
|
||||||
|
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t seq_len = ids_dim.back();
|
||||||
|
size_t string_batch = ids.NumberOfElement() / seq_len;
|
||||||
|
|
||||||
|
std::vector<std::string> decoded_strings;
|
||||||
|
decoded_strings.reserve(string_batch);
|
||||||
|
const std::string ws = " ";
|
||||||
|
for (auto n = string_batch; n > 0; n--) {
|
||||||
|
std::string text;
|
||||||
|
for (int64_t i = 0; i < seq_len; ++i) {
|
||||||
|
std::string token;
|
||||||
|
Id2Token(p_ids[i], token, nullptr);
|
||||||
|
if (token.find(spm_escaped_space) == 0) {
|
||||||
|
token = ws + token.substr(spm_escaped_space.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
text += token;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tokenizer_add_space_prefix_) {
|
||||||
|
if (text.length() > 0 && text[0] == ' ') {
|
||||||
|
text = text.substr(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decoded_strings.push_back(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
output.SetStringOutput(decoded_strings, output_dim);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** /* state */) const {
|
||||||
|
if (special_token_ids_.count(id)) {
|
||||||
|
token = "";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (id >= vocab_.size()) {
|
||||||
|
token = unknown_token_;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
token = vocab_[id];
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool tokenizer_add_space_prefix_ = true;
|
||||||
|
std::vector<std::string> vocab_;
|
||||||
|
std::string unknown_token_ = "<unk>";
|
||||||
|
std::set<extTokenId_t> special_token_ids_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace ort_extensions
|
} // namespace ort_extensions
|
||||||
|
|
|
@ -1,161 +0,0 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT License.
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ustring.h"
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace ort_extensions {
|
|
||||||
|
|
||||||
inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
|
|
||||||
inline bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
|
|
||||||
|
|
||||||
inline unsigned int hex_digit_to_int(char c) {
|
|
||||||
unsigned int x = static_cast<unsigned char>(c);
|
|
||||||
if (x > '9') {
|
|
||||||
x += 9;
|
|
||||||
}
|
|
||||||
return x & 0xf;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IsSurrogate(char32_t c) {
|
|
||||||
return c >= 0xD800 && c <= 0xDFFF;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unescape a Python escaped string
|
|
||||||
inline bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
|
|
||||||
|
|
||||||
// reserve enough space for the worst case, and final size will be calculated at the end.
|
|
||||||
unescaped.resize(source.length());
|
|
||||||
char* d = unescaped.data();
|
|
||||||
const char* p = source.data();
|
|
||||||
const char* end = p + source.size();
|
|
||||||
const char* last_byte = end - 1;
|
|
||||||
|
|
||||||
while (p == d && p < end && *p != '\\') p++, d++;
|
|
||||||
|
|
||||||
while (p < end) {
|
|
||||||
if (*p != '\\') {
|
|
||||||
*d++ = *p++;
|
|
||||||
} else {
|
|
||||||
if (++p > last_byte) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
switch (*p) {
|
|
||||||
case 'n':
|
|
||||||
*d++ = '\n';
|
|
||||||
break;
|
|
||||||
case 'r':
|
|
||||||
*d++ = '\r';
|
|
||||||
break;
|
|
||||||
case 't':
|
|
||||||
*d++ = '\t';
|
|
||||||
break;
|
|
||||||
break;
|
|
||||||
case '\\':
|
|
||||||
*d++ = '\\';
|
|
||||||
break;
|
|
||||||
case '\'':
|
|
||||||
*d++ = '\'';
|
|
||||||
break;
|
|
||||||
case '"':
|
|
||||||
*d++ = '\"';
|
|
||||||
break;
|
|
||||||
case 'x':
|
|
||||||
case 'X': {
|
|
||||||
if (p >= last_byte) {
|
|
||||||
return false;
|
|
||||||
} else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
unsigned int ch = 0;
|
|
||||||
const char* hex_start = p;
|
|
||||||
while (p < last_byte &&
|
|
||||||
IsHexDigit(static_cast<unsigned char>(p[1])))
|
|
||||||
ch = (ch << 4) + hex_digit_to_int(*++p);
|
|
||||||
if (ch > 0xFF && !is_bytes) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (is_bytes) {
|
|
||||||
*d++ = static_cast<char>(ch);
|
|
||||||
} else {
|
|
||||||
d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'u': {
|
|
||||||
char32_t rune = 0;
|
|
||||||
const char* hex_start = p;
|
|
||||||
if (p + 4 >= end) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
|
||||||
rune = (rune << 4) + hex_digit_to_int(*++p);
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (IsSurrogate(rune)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
d += ustring::EncodeUTF8Char(d, rune);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'U': {
|
|
||||||
char32_t rune = 0;
|
|
||||||
const char* hex_start = p;
|
|
||||||
if (p + 8 >= end) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 8; ++i) {
|
|
||||||
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
|
||||||
uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
|
|
||||||
if (newrune > 0x10FFFF) {
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
rune = newrune;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (IsSurrogate(rune)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
d += ustring::EncodeUTF8Char(d, rune);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unescaped.resize(d - unescaped.data());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool UnquoteString(const std::string& str, std::string& unquoted) {
|
|
||||||
bool is_bytes = false;
|
|
||||||
int idx_0 = 0;
|
|
||||||
if (str.front() == 'b') {
|
|
||||||
is_bytes = true;
|
|
||||||
idx_0 = 1;
|
|
||||||
}
|
|
||||||
std::string str_view(str.data() + idx_0, str.length() - idx_0);
|
|
||||||
if (str_view.length() < 2) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// unescape the string
|
|
||||||
return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace ort_extensions
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
#include "ustring.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace ort_extensions {
|
||||||
|
|
||||||
|
class UnescapeUtils {
|
||||||
|
public:
|
||||||
|
static bool IsDigit(char c) { return c >= '0' && c <= '9'; }
|
||||||
|
static bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
|
||||||
|
|
||||||
|
static unsigned int hex_digit_to_int(char c) {
|
||||||
|
unsigned int x = static_cast<unsigned char>(c);
|
||||||
|
if (x > '9') {
|
||||||
|
x += 9;
|
||||||
|
}
|
||||||
|
return x & 0xf;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool IsSurrogate(char32_t c) {
|
||||||
|
return c >= 0xD800 && c <= 0xDFFF;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unescape a Python escaped string
|
||||||
|
static bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
|
||||||
|
|
||||||
|
// reserve enough space for the worst case, and final size will be calculated at the end.
|
||||||
|
unescaped.resize(source.length());
|
||||||
|
char* d = unescaped.data();
|
||||||
|
const char* p = source.data();
|
||||||
|
const char* end = p + source.size();
|
||||||
|
const char* last_byte = end - 1;
|
||||||
|
|
||||||
|
while (p == d && p < end && *p != '\\') p++, d++;
|
||||||
|
|
||||||
|
while (p < end) {
|
||||||
|
if (*p != '\\') {
|
||||||
|
*d++ = *p++;
|
||||||
|
} else {
|
||||||
|
if (++p > last_byte) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
switch (*p) {
|
||||||
|
case 'n':
|
||||||
|
*d++ = '\n';
|
||||||
|
break;
|
||||||
|
case 'r':
|
||||||
|
*d++ = '\r';
|
||||||
|
break;
|
||||||
|
case 't':
|
||||||
|
*d++ = '\t';
|
||||||
|
break;
|
||||||
|
break;
|
||||||
|
case '\\':
|
||||||
|
*d++ = '\\';
|
||||||
|
break;
|
||||||
|
case '\'':
|
||||||
|
*d++ = '\'';
|
||||||
|
break;
|
||||||
|
case '"':
|
||||||
|
*d++ = '\"';
|
||||||
|
break;
|
||||||
|
case 'x':
|
||||||
|
case 'X': {
|
||||||
|
if (p >= last_byte) {
|
||||||
|
return false;
|
||||||
|
} else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
unsigned int ch = 0;
|
||||||
|
const char* hex_start = p;
|
||||||
|
while (p < last_byte &&
|
||||||
|
IsHexDigit(static_cast<unsigned char>(p[1])))
|
||||||
|
ch = (ch << 4) + hex_digit_to_int(*++p);
|
||||||
|
if (ch > 0xFF && !is_bytes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (is_bytes) {
|
||||||
|
*d++ = static_cast<char>(ch);
|
||||||
|
} else {
|
||||||
|
d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 'u': {
|
||||||
|
char32_t rune = 0;
|
||||||
|
const char* hex_start = p;
|
||||||
|
if (p + 4 >= end) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
||||||
|
rune = (rune << 4) + hex_digit_to_int(*++p);
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (IsSurrogate(rune)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
d += ustring::EncodeUTF8Char(d, rune);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 'U': {
|
||||||
|
char32_t rune = 0;
|
||||||
|
const char* hex_start = p;
|
||||||
|
if (p + 8 >= end) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
|
||||||
|
uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
|
||||||
|
if (newrune > 0x10FFFF) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
rune = newrune;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (IsSurrogate(rune)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
d += ustring::EncodeUTF8Char(d, rune);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unescaped.resize(d - unescaped.data());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool UnquoteString(const std::string& str, std::string& unquoted) {
|
||||||
|
bool is_bytes = false;
|
||||||
|
int idx_0 = 0;
|
||||||
|
if (str.front() == 'b') {
|
||||||
|
is_bytes = true;
|
||||||
|
idx_0 = 1;
|
||||||
|
}
|
||||||
|
std::string str_view(str.data() + idx_0, str.length() - idx_0);
|
||||||
|
if (str_view.length() < 2) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unescape the string
|
||||||
|
return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ort_extensions
|
|
@ -1,7 +1,7 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "wordpiece_tokenizer.hpp"
|
#include "wordpiece_tokenizer.h"
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
|
|
||||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||||
|
|
|
@ -13,7 +13,7 @@ class DetokenizerCache : public OrtxObjectImpl {
|
||||||
DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
|
DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
|
||||||
~DetokenizerCache() override = default;
|
~DetokenizerCache() override = default;
|
||||||
|
|
||||||
std::unique_ptr<BPEDecoderState> decoder_state_{};
|
std::unique_ptr<TokenizerDecodingState> decoder_state_{};
|
||||||
std::string last_text_{}; // last detokenized text
|
std::string last_text_{}; // last detokenized text
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "bpe_kernels.h"
|
#include "bpe_kernels.h"
|
||||||
#include "bpe_tokenizer.hpp"
|
#include "bpe_tokenizer_model.hpp"
|
||||||
#include "bpe_decoder.hpp"
|
#include "bpe_decoder.hpp"
|
||||||
#include "ugm_kernels.hpp"
|
#include "ugm_kernels.hpp"
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ TokenizerImpl::TokenizerImpl()
|
||||||
TokenizerImpl::~TokenizerImpl() {};
|
TokenizerImpl::~TokenizerImpl() {};
|
||||||
|
|
||||||
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
||||||
tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
|
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
|
||||||
auto status = tok_config_->Load(tok_path);
|
auto status = tok_config_->Load(tok_path);
|
||||||
if (!status.IsOk()) {
|
if (!status.IsOk()) {
|
||||||
return status;
|
return status;
|
||||||
|
@ -25,8 +25,18 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
||||||
if (tok_config_->tokenizer_class_.empty()) {
|
if (tok_config_->tokenizer_class_.empty()) {
|
||||||
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
|
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
|
||||||
status = tokenizer->Load(*tok_config_);
|
status = tokenizer->Load(*tok_config_);
|
||||||
|
if (!status.IsOk()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
auto detok = std::make_unique<SpmUgmDecoder>();
|
||||||
|
|
||||||
|
if (status.IsOk()) {
|
||||||
|
status = detok->Load(*tok_config_, *tokenizer);
|
||||||
|
}
|
||||||
|
|
||||||
if (status.IsOk()) {
|
if (status.IsOk()) {
|
||||||
tokenizer_ = std::move(tokenizer);
|
tokenizer_ = std::move(tokenizer);
|
||||||
|
detokenizer_ = std::move(detok);
|
||||||
}
|
}
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
|
@ -38,14 +48,16 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
||||||
auto fx_load = vocab_file_path.extension() == ".json"?
|
auto fx_load = vocab_file_path.extension() == ".json"?
|
||||||
&JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
|
&JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
|
||||||
status = (tokenizer.get()->*fx_load)(*tok_config_);
|
status = (tokenizer.get()->*fx_load)(*tok_config_);
|
||||||
|
if (!status.IsOk()) {
|
||||||
if (status.IsOk()) {
|
return status;
|
||||||
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
|
|
||||||
status = detokenizer_->Load(tok_config_, *tokenizer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto detok = std::make_unique<BpeStreamingDecoder>();
|
||||||
|
status = detok->Load(tok_config_, *tokenizer);
|
||||||
|
|
||||||
if (status.IsOk()) {
|
if (status.IsOk()) {
|
||||||
tokenizer_ = std::move(tokenizer);
|
tokenizer_ = std::move(tokenizer);
|
||||||
|
detokenizer_ = std::move(detok);
|
||||||
}
|
}
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
|
@ -81,7 +93,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
|
||||||
std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
|
std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
|
||||||
ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
|
ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
|
||||||
ortc::Tensor<std::string> ts_output;
|
ortc::Tensor<std::string> ts_output;
|
||||||
OrtxStatus status = detokenizer_->Compute(ts_input, ts_output);
|
OrtxStatus status = std::visit([&](auto& detokenizer) {
|
||||||
|
return detokenizer->Compute(ts_input, ts_output); }, detokenizer_);
|
||||||
|
|
||||||
if (!status.IsOk()) {
|
if (!status.IsOk()) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -90,8 +104,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
|
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const {
|
||||||
return detokenizer_->Id2Token(id, token, state);
|
return std::visit([&](auto& detokenizer) {
|
||||||
|
return detokenizer->Id2Token(id, token, state); }, detokenizer_);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::map<std::string, std::string> LANGUAGES = {
|
static std::map<std::string, std::string> LANGUAGES = {
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
#include "bpe_kernels.h"
|
#include "bpe_kernels.h"
|
||||||
#include "ugm_kernels.hpp"
|
#include "ugm_kernels.hpp"
|
||||||
#include "bpe_jsoncfg.hpp"
|
#include "tokenizer_jsconfig.hpp"
|
||||||
#include "bpe_streaming.hpp"
|
#include "bpe_streaming.hpp"
|
||||||
#include "c_api_utils.hpp"
|
#include "c_api_utils.hpp"
|
||||||
|
|
||||||
|
@ -34,8 +34,8 @@ class TokenizerImpl : public OrtxObjectImpl {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
|
OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<TokenizerDecodingState>& cache) const {
|
||||||
BPEDecoderState* state_ptr = cache.get();
|
TokenizerDecodingState* state_ptr = cache.get();
|
||||||
OrtxStatus status = Id2Token(id, token, &state_ptr);
|
OrtxStatus status = Id2Token(id, token, &state_ptr);
|
||||||
if (status.IsOk()) {
|
if (status.IsOk()) {
|
||||||
if (state_ptr != cache.get()) {
|
if (state_ptr != cache.get()) {
|
||||||
|
@ -51,7 +51,7 @@ class TokenizerImpl : public OrtxObjectImpl {
|
||||||
|
|
||||||
OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;
|
OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;
|
||||||
|
|
||||||
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;
|
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const;
|
||||||
|
|
||||||
OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
|
OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
|
||||||
std::vector<std::vector<extTokenId_t>>& t_ids) const;
|
std::vector<std::vector<extTokenId_t>>& t_ids) const;
|
||||||
|
@ -61,8 +61,11 @@ class TokenizerImpl : public OrtxObjectImpl {
|
||||||
using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
|
using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
|
||||||
std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;
|
std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;
|
||||||
|
|
||||||
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
|
using bpe_decoder_t = std::unique_ptr<BpeStreamingDecoder>;
|
||||||
std::unique_ptr<BpeStreamingDecoder> detokenizer_;
|
using ugm_decoder_t = std::unique_ptr<SpmUgmDecoder>;
|
||||||
|
std::variant<bpe_decoder_t, ugm_decoder_t> detokenizer_;
|
||||||
|
|
||||||
|
std::shared_ptr<ort_extensions::TokenJsonConfig> tok_config_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ort_extensions
|
} // namespace ort_extensions
|
||||||
|
|
|
@ -335,7 +335,7 @@ TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
|
||||||
EXPECT_EQ(token_ids.size(), 1);
|
EXPECT_EQ(token_ids.size(), 1);
|
||||||
|
|
||||||
std::string text;
|
std::string text;
|
||||||
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
|
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
|
||||||
// token_ids[0].insert(token_ids[0].begin() + 2, 607); // <0x20>
|
// token_ids[0].insert(token_ids[0].begin() + 2, 607); // <0x20>
|
||||||
token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
|
token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
|
||||||
for (const auto& token_id : token_ids[0]) {
|
for (const auto& token_id : token_ids[0]) {
|
||||||
|
@ -369,7 +369,7 @@ TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
|
||||||
DumpTokenIds(token_ids);
|
DumpTokenIds(token_ids);
|
||||||
|
|
||||||
std::string text;
|
std::string text;
|
||||||
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
|
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
|
||||||
// std::cout << "\"";
|
// std::cout << "\"";
|
||||||
for (const auto& token_id : token_ids[0]) {
|
for (const auto& token_id : token_ids[0]) {
|
||||||
std::string token;
|
std::string token;
|
||||||
|
@ -406,7 +406,7 @@ TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
|
||||||
DumpTokenIds(token_ids);
|
DumpTokenIds(token_ids);
|
||||||
|
|
||||||
std::string text;
|
std::string text;
|
||||||
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
|
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
|
||||||
// std::cout << "\"";
|
// std::cout << "\"";
|
||||||
for (const auto& token_id : token_ids[0]) {
|
for (const auto& token_id : token_ids[0]) {
|
||||||
std::string token;
|
std::string token;
|
||||||
|
@ -464,4 +464,20 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
|
||||||
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||||
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
|
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
|
||||||
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
|
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
|
||||||
|
|
||||||
|
OrtxObjectPtr<OrtxStringArray> decoded_text;
|
||||||
|
OrtxDetokenize(tokenizer.get(), token_ids.get(), ort_extensions::ptr(decoded_text));
|
||||||
|
EXPECT_EQ(decoded_text.Code(), kOrtxOK);
|
||||||
|
|
||||||
|
const char* text = nullptr;
|
||||||
|
OrtxStringArrayGetItem(decoded_text.get(), 0, &text);
|
||||||
|
// because the tokenization remove the character from the string, the decoded text is not the same as the input text.
|
||||||
|
std::string filtered_text(input[0]);
|
||||||
|
filtered_text.erase(std::remove_if(
|
||||||
|
filtered_text.begin(), filtered_text.end(), [](unsigned char chr){ return chr < 0x20; }), filtered_text.end());
|
||||||
|
// remove the consecutive spaces
|
||||||
|
filtered_text.erase(std::unique(filtered_text.begin(), filtered_text.end(),
|
||||||
|
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }), filtered_text.end());
|
||||||
|
|
||||||
|
EXPECT_STREQ(filtered_text.c_str(), text);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
#include "wordpiece_tokenizer.hpp"
|
#include "wordpiece_tokenizer.h"
|
||||||
#include "bert_tokenizer.hpp"
|
#include "bert_tokenizer.hpp"
|
||||||
|
|
||||||
#include <clocale>
|
#include <clocale>
|
||||||
|
|
Загрузка…
Ссылка в новой задаче