* initial checkins

* fix the selectedops build failures

* add the tokenization implementation

* update the Windows DEF file for the C ABI in the CMake file

* fix the build on Linux

* fix some warnings and remove the unused code

* initial import of unit tests from tfmtok

* add streaming API support

* fix the merges loading issues

* complete export from tfmtok - needs input id fixing

* fix the unit test failures

* fix all unit test failures

* refactor streaming code

* remove the unused code

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Wenbing Li 2024-04-29 16:45:49 -07:00 committed by GitHub
Parent 1f31d33ed4
Commit a8bce4328b
No known key found for this signature
GPG key ID: B5690EEEBB952194
38 changed files with 1544077 additions and 156 deletions

View file

@ -564,6 +564,7 @@ standardize_output_folder(ocos_operators)
target_include_directories(noexcep_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${GSL_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base
@ -571,6 +572,7 @@ target_include_directories(noexcep_operators PUBLIC
target_include_directories(ocos_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${GSL_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base
@ -669,7 +671,7 @@ if(OCOS_ENABLE_BLINGFIRE)
endif()
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
target_include_directories(ocos_operators PRIVATE ${nlohmann_json_SOURCE_DIR}/single_include)
target_include_directories(ocos_operators PUBLIC ${nlohmann_json_SOURCE_DIR}/single_include)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()
@ -684,7 +686,6 @@ if(ANDROID)
list(APPEND ocos_libraries log)
endif()
target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
list(APPEND ocos_libraries Microsoft.GSL::GSL)
list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)
@ -705,6 +706,10 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})
file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h")
if (NOT OCOS_ENABLE_C_API)
file(GLOB shared_TARGET_LIB_C_API_SRC "shared/lib/*tokenizer*")
list(REMOVE_ITEM shared_TARGET_LIB_SRC ${shared_TARGET_LIB_C_API_SRC})
endif()
if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
add_executable(ortcustomops ${shared_TARGET_LIB_SRC})
@ -805,7 +810,12 @@ target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators
target_link_libraries(ortcustomops PUBLIC ocos_operators)
if(_BUILD_SHARED_LIBRARY)
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h")
if (OCOS_ENABLE_C_API)
list(APPEND shared_TARGET_SRC "shared/extensions_c.def")
else()
list(APPEND shared_TARGET_SRC "shared/ortcustomops.def")
endif()
add_library(extensions_shared SHARED ${shared_TARGET_SRC})
# We need to propagate OCOS_SHARED_LIBRARY if set.

View file

@ -12,12 +12,6 @@ struct OrtxStatus::Rep {
OrtxStatus::OrtxStatus() = default;
OrtxStatus::~OrtxStatus() = default;
OrtxStatus::OrtxStatus(extError_t code, std::string_view error_message)
: rep_(new Rep) {
rep_->code = code;
rep_->error_message = std::string(error_message);
}
OrtxStatus::OrtxStatus(extError_t code, const std::string& error_message)
: rep_(new Rep) {
rep_->code = code;
@ -58,3 +52,51 @@ OrtStatus* OrtxStatus::CreateOrtStatus() const {
OrtStatus* status = OrtW::CreateStatus(Message(), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
return status;
}
std::string OrtxStatus::ToString() const {
if (rep_ == nullptr)
return "OK";
std::string result;
switch (Code()) {
case extError_t::kOrtxOK:
result = "Success";
break;
case extError_t::kOrtxErrorInvalidArgument:
result = "Invalid argument";
break;
case extError_t::kOrtxErrorOutOfMemory:
result = "Out of Memory";
break;
case extError_t::kOrtxErrorCorruptData:
result = "Corrupt data";
break;
case extError_t::kOrtxErrorInvalidFile:
result = "Invalid data file";
break;
case extError_t::kOrtxErrorNotFound:
result = "Not found";
break;
case extError_t::kOrtxErrorAlreadyExists:
result = "Already exists";
break;
case extError_t::kOrtxErrorOutOfRange:
result = "Out of range";
break;
case extError_t::kOrtxErrorNotImplemented:
result = "Not implemented";
break;
case extError_t::kOrtxErrorInternal:
result = "Internal";
break;
case extError_t::kOrtxErrorUnknown:
result = "Unknown";
break;
default:
break;
}
result += ": ";
result += rep_->error_message;
return result;
}
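
As a quick illustration (not part of this change), the new ToString() composes its output as the error-code text, a colon, then the stored message; the ReportFailure helper and the file name in the message below are made up.

// Sketch only: ToString() prefixes the stored message with the code's text.
#include <iostream>
#include "status.h"  // the same header included elsewhere in this change

void ReportFailure() {
  OrtxStatus status(kOrtxErrorInvalidFile, "Failed to open a json file: tokenizer.json");
  // Prints: "Invalid data file: Failed to open a json file: tokenizer.json"
  std::cout << status.ToString() << std::endl;
}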

View file

@ -11,7 +11,6 @@ struct OrtStatus;
struct OrtxStatus {
OrtxStatus();
~OrtxStatus();
OrtxStatus(extError_t code, std::string_view error_message);
OrtxStatus(extError_t code, const std::string& error_message);
OrtxStatus(const OrtxStatus& s);
OrtxStatus& operator=(const OrtxStatus& s);
@ -22,6 +21,7 @@ struct OrtxStatus {
void SetErrorMessage(const char* str);
[[nodiscard]] const char* Message() const;
[[nodiscard]] extError_t Code() const;
std::string ToString() const;
OrtStatus* CreateOrtStatus() const;

View file

@ -58,6 +58,24 @@ class ustring : public std::u32string {
return std::string(utf8_buf);
}
static size_t UTF8Len(char byte1) {
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(byte1) >> 4;
return lookup[highbits];
}
static size_t UTF8Len(char32_t codepoint) {
if (codepoint <= 0x7F) {
return 1;
} else if (codepoint <= 0x7FF) {
return 2;
} else if (codepoint <= 0xFFFF) {
return 3;
} else {
return 4;
}
}
static bool ValidateUTF8(const std::string& data) {
int cnt = 0;
for (auto i = 0; i < data.size(); i++) {
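
The lookup table above maps the high nibble of a UTF-8 lead byte to a sequence length (1 for ASCII and continuation bytes, 2 for 0xC0-0xDF, 3 for 0xE0-0xEF, 4 for 0xF0-0xFF). The streaming decoder added later in this change uses it to hold back an incomplete multi-byte sequence; below is a minimal sketch of that check, where IsIncompleteUtf8 is an illustrative helper, not code from the diff.

// Sketch only: the same incomplete-UTF-8 check the streaming decoder performs
// on a decoded token (see bpe_streaming.hpp in this change).
#include <string>
#include "ustring.h"

bool IsIncompleteUtf8(const std::string& token) {
  // e.g. "\xE2\x96" (first two bytes of a 3-byte sequence): UTF8Len == 3 > size 2
  return !token.empty() && ustring::UTF8Len(token.front()) > token.size();
}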

View file

@ -199,19 +199,22 @@ else()
endif()
endblock()
block()
file(GLOB tokenizer_TEST_SRC
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
"${TEST_SRC_DIR}/tokenizer_test/*.hpp")
if (OCOS_ENABLE_C_API)
file(GLOB tokenizer_TEST_SRC
"${TEST_SRC_DIR}/tokenizer_test/*.c"
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
"${TEST_SRC_DIR}/tokenizer_test/*.h")
add_test_target(TARGET tokenizer_api_test
TEST_SOURCES ${tokenizer_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
add_test_target(TARGET tokenizer_api_test
TEST_SOURCES ${tokenizer_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
endblock()
target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(tokenizer_api_test PRIVATE
${PROJECT_SOURCE_DIR}/
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
endif()
endif()
endif()

45
cmake/externals/farmhash/src/farmhash.cc vendored
View file

@ -390,11 +390,14 @@ STATIC_INLINE uint32_t Mur(uint32_t a, uint32_t h) {
template <typename T> STATIC_INLINE T DebugTweak(T x) {
if (debug_mode) {
if (sizeof(x) == 4) {
x = ~Bswap32(x * c1);
} else {
x = ~Bswap64(x * k1);
}
x = ~Bswap32(x * c1);
}
return x;
}
template <> uint64_t DebugTweak(uint64_t x) {
if (debug_mode) {
x = ~Bswap64(x * k1);
}
return x;
}
@ -461,7 +464,7 @@ STATIC_INLINE uint64_t HashLen0to16(const char *s, size_t len) {
uint8_t b = s[len >> 1];
uint8_t c = s[len - 1];
uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
uint32_t z = len + (static_cast<uint32_t>(c) << 2);
uint32_t z = static_cast<uint32_t>(len + (static_cast<uint32_t>(c) << 2));
return ShiftMix(y * k2 ^ z * k0) * k2;
}
return k2;
@ -1002,7 +1005,7 @@ namespace farmhashnt {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@ -1039,7 +1042,7 @@ STATIC_INLINE uint32_t Hash32Len13to24(const char *s, size_t len, uint32_t seed
uint32_t d = Fetch(s + (len >> 1));
uint32_t e = Fetch(s);
uint32_t f = Fetch(s + len - 4);
uint32_t h = d * c1 + len + seed;
uint32_t h = static_cast<uint32_t>(d * c1 + len + seed);
a = Rotate(a, 12) + f;
h = Mur(c, h) + a;
a = Rotate(a, 3) + c;
@ -1057,11 +1060,11 @@ STATIC_INLINE uint32_t Hash32Len0to4(const char *s, size_t len, uint32_t seed =
b = b * c1 + v;
c ^= b;
}
return fmix(Mur(b, Mur(len, c)));
return fmix(Mur(b, Mur(static_cast<uint32_t>(len), c)));
}
STATIC_INLINE uint32_t Hash32Len5to12(const char *s, size_t len, uint32_t seed = 0) {
uint32_t a = len, b = len * 5, c = 9, d = b + seed;
uint32_t a = static_cast<uint32_t>(len), b = static_cast<uint32_t>(len * 5), c = 9, d = b + seed;
a += Fetch(s);
b += Fetch(s + len - 4);
c += Fetch(s + ((len >> 1) & 4));
@ -1076,7 +1079,7 @@ uint32_t Hash32(const char *s, size_t len) {
}
// len > 24
uint32_t h = len, g = c1 * len, f = g;
uint32_t h = static_cast<uint32_t>(len), g = static_cast<uint32_t>(c1 * len), f = g;
uint32_t a0 = Rotate(Fetch(s + len - 4) * c1, 17) * c2;
uint32_t a1 = Rotate(Fetch(s + len - 8) * c1, 17) * c2;
uint32_t a2 = Rotate(Fetch(s + len - 16) * c1, 17) * c2;
@ -1132,7 +1135,7 @@ uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
else if (len >= 5) return Hash32Len5to12(s, len, seed);
else return Hash32Len0to4(s, len, seed);
}
uint32_t h = Hash32Len13to24(s, 24, seed ^ len);
uint32_t h = Hash32Len13to24(s, 24, seed ^ static_cast<uint32_t>(len));
return Mur(Hash32(s + 24, len - 24) + seed, h);
}
} // namespace farmhashmk
@ -1141,7 +1144,7 @@ namespace farmhashsu {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@ -1361,7 +1364,7 @@ namespace farmhashsa {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@ -1587,7 +1590,7 @@ STATIC_INLINE uint32_t Hash32Len13to24(const char *s, size_t len) {
uint32_t d = Fetch(s + (len >> 1));
uint32_t e = Fetch(s);
uint32_t f = Fetch(s + len - 4);
uint32_t h = len;
uint32_t h = static_cast<uint32_t>(len);
return fmix(Mur(f, Mur(e, Mur(d, Mur(c, Mur(b, Mur(a, h)))))));
}
@ -1600,11 +1603,11 @@ STATIC_INLINE uint32_t Hash32Len0to4(const char *s, size_t len) {
b = b * c1 + v;
c ^= b;
}
return fmix(Mur(b, Mur(len, c)));
return fmix(Mur(b, Mur(static_cast<uint32_t>(len), c)));
}
STATIC_INLINE uint32_t Hash32Len5to12(const char *s, size_t len) {
uint32_t a = len, b = len * 5, c = 9, d = b;
uint32_t a = static_cast<uint32_t>(len), b = static_cast<uint32_t>(len * 5), c = 9, d = b;
a += Fetch(s);
b += Fetch(s + len - 4);
c += Fetch(s + ((len >> 1) & 4));
@ -1619,7 +1622,7 @@ uint32_t Hash32(const char *s, size_t len) {
}
// len > 24
uint32_t h = len, g = c1 * len, f = g;
uint32_t h = static_cast<uint32_t>(len), g = static_cast<uint32_t>(c1 * len), f = g;
uint32_t a0 = Rotate(Fetch(s + len - 4) * c1, 17) * c2;
uint32_t a1 = Rotate(Fetch(s + len - 8) * c1, 17) * c2;
uint32_t a2 = Rotate(Fetch(s + len - 16) * c1, 17) * c2;
@ -1686,7 +1689,7 @@ uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
else if (len >= 5) return farmhashmk::Hash32Len5to12(s, len, seed);
else return farmhashmk::Hash32Len0to4(s, len, seed);
}
uint32_t h = farmhashmk::Hash32Len13to24(s, 24, seed ^ len);
uint32_t h = farmhashmk::Hash32Len13to24(s, 24, seed ^ static_cast<uint32_t>(len));
return Mur(Hash32(s + 24, len - 24) + seed, h);
}
@ -1736,7 +1739,7 @@ STATIC_INLINE uint64_t HashLen0to16(const char *s, size_t len) {
uint8_t b = s[len >> 1];
uint8_t c = s[len - 1];
uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
uint32_t z = len + (static_cast<uint32_t>(c) << 2);
uint32_t z = static_cast<uint32_t>(len + (static_cast<uint32_t>(c) << 2));
return ShiftMix(y * k2 ^ z * k0) * k2;
}
return k2;
@ -1775,7 +1778,7 @@ STATIC_INLINE NAMESPACE_FOR_HASH_FUNCTIONS::uint128_t CityMurmur(const char *s,
uint64_t b = Uint128High64(seed);
uint64_t c = 0;
uint64_t d = 0;
signed long l = len - 16;
signed long l = static_cast<signed long>(len) - 16;
if (l <= 0) { // len <= 16
a = ShiftMix(a * k1) * k1;
c = b * k1 + HashLen0to16(s, len);

View file

@ -78,7 +78,6 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
/** \brief Create a tokenizer object with the specified tokenizer path
*
* \param tokenizer Pointer to store the created tokenizer object

View file

@ -19,6 +19,8 @@
struct KernelBpeDecoder {
public:
virtual ~KernelBpeDecoder() = default;
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
@ -96,27 +98,26 @@ struct KernelBpeDecoder {
void BuildIdVocab(const std::string& vocab) {
arr_vocab_.reserve(vocab.size() / 2); // give a rough estimation.
std::u32string u_vocab = ustring(vocab);
std::u32string_view uv_vocab(u_vocab);
std::string_view v_vocab(vocab);
size_t last_pos = 0;
auto ccount = uv_vocab.size();
for (size_t n = 0; n < ccount; ++n) {
if (uv_vocab[n] == char32_t('\n')) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos);
arr_vocab_.emplace_back(ustring(s_tok));
auto c_count = v_vocab.size();
for (size_t n = 0; n < c_count; ++n) {
if (v_vocab[n] == '\n') {
std::string_view s_tok = v_vocab.substr(last_pos, n - last_pos);
arr_vocab_.emplace_back(std::string(s_tok));
last_pos = n + 1;
} else if (n == ccount - 1) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos + 1);
arr_vocab_.emplace_back(ustring(s_tok));
} else if (n == c_count - 1) {
std::string_view s_tok = v_vocab.substr(last_pos, n - last_pos + 1);
arr_vocab_.emplace_back(std::string(s_tok));
}
}
arr_vocab_.shrink_to_fit();
}
OrtStatusPtr Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
@ -146,10 +147,10 @@ struct KernelBpeDecoder {
}
if (added_tokens_.count(token)) {
const std::string ws = added_tokens_.at(token);
const std::string& ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
} else if (static_cast<size_t>(token) < arr_vocab_.size()) {
const auto str = arr_vocab_[token];
const auto str = ustring(arr_vocab_[token]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {
@ -173,6 +174,14 @@ struct KernelBpeDecoder {
decoded_token = unk_token_;
}
}
// remove the end_of_word_suffix like </w> or </s> etc.
if (end_of_word_suffix_.size() > 0) {
if (decoded_token.size() >= end_of_word_suffix_.size() &&
decoded_token.substr(decoded_token.size() - end_of_word_suffix_.size()) == end_of_word_suffix_) {
decoded_token = decoded_token.substr(0, decoded_token.size() - end_of_word_suffix_.size());
decoded_token += ' ';
}
}
if (whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {
@ -193,10 +202,10 @@ struct KernelBpeDecoder {
p_ids += seq_len;
}
output.SetStringOutput(decoded_strings, output_dim);
return nullptr;
return {};
}
private:
protected:
std::string bos_token_{"<|endoftext|>"};
std::string eos_token_{"<|endoftext|>"};
std::string unk_token_{"<|endoftext|>"};
@ -206,8 +215,9 @@ struct KernelBpeDecoder {
int64_t en_normalization_ = 0;
int64_t skip_special_tokens_ = 0;
int64_t whitespace_token_ = 0;
std::vector<ustring> arr_vocab_;
std::vector<std::string> arr_vocab_;
std::map<char32_t, unsigned char> byte_decoder_;
std::map<int64_t, std::string> added_tokens_;
std::set<int64_t> all_special_ids_;
std::string end_of_word_suffix_;
};

View file

@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <fstream>
#include <filesystem>
#include "ocos.h"
#include "status.h"
#include "nlohmann/json.hpp"
#include "bpe_types.h"
namespace ort_extensions::bpe {
class TokenJsonConfig final {
public:
TokenJsonConfig() {}
~TokenJsonConfig() {}
using json = nlohmann::json;
using json_pointer = nlohmann::json_pointer<std::string>;
public:
OrtxStatus Load(const std::string& json_path) {
if (json_path.empty()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}
auto file_path = std::filesystem::path(json_path) / "tokenizer_config.json";
std::ifstream ifs(file_path);
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
}
vocab_path_ = (std::filesystem::path(json_path) / "tokenizer.json").string();
nlohmann::json json_config = nlohmann::json::parse(ifs);
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
clean_up_tokenization_spaces_ = json_config.value("clean_up_tokenization_spaces", false);
model_max_length_ = json_config.value("model_max_length", 1e+30);
tokenizer_class_ = json_config.value("tokenizer_class", "");
auto tok_iter = json_config.find("bos_token");
if (tok_iter != json_config.end() && tok_iter->is_object()) {
bos_token_ = tok_iter->value("content", "");
eos_token_ = json_config.value("/eos_token/content"_json_pointer, "");
unk_token_ = json_config.value("/unk_token/content"_json_pointer, "");
} else {
bos_token_ = json_config.value("bos_token", "");
eos_token_ = json_config.value("eos_token", "");
unk_token_ = json_config.value("unk_token", "");
}
auto pad_iter = json_config.find("pad_token");
if (pad_iter != json_config.end() && pad_iter->is_string()) {
pad_token_ = json_config.value("pad_token", "");
}
if (tokenizer_class_.empty()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get tokenizer class.");
}
return {};
}
const std::string& GetVocabDataFile() const {
return vocab_path_;
}
public:
bool add_bos_token_{};
bool add_eos_token_{};
bool clean_up_tokenization_spaces_{};
double model_max_length_{};
std::string tokenizer_class_;
std::string bos_token_;
std::string eos_token_;
std::string unk_token_;
std::string pad_token_;
private:
std::string vocab_path_;
};
} // namespace ort_extensions::bpe
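
For illustration, here is a fabricated tokenizer_config.json fragment covering the keys Load() above reads; bos_token can be either a plain string or an object with a content field, and both branches are handled. The values and the kSampleTokenizerConfig/ParseSampleTokenizerConfig names are made up for the sketch.

// Illustration only: a fabricated config covering the keys TokenJsonConfig::Load
// consumes; parse it the same way Load() does.
#include "nlohmann/json.hpp"

static const char* kSampleTokenizerConfig = R"({
  "add_bos_token": true,
  "add_eos_token": false,
  "clean_up_tokenization_spaces": false,
  "model_max_length": 2048,
  "tokenizer_class": "LlamaTokenizer",
  "bos_token": { "content": "<s>" },
  "eos_token": { "content": "</s>" },
  "unk_token": { "content": "<unk>" },
  "pad_token": "</s>"
})";

inline nlohmann::json ParseSampleTokenizerConfig() {
  return nlohmann::json::parse(kSampleTokenizerConfig);
}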

View file

@ -1,15 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "bpe_tokenizer.hpp"
#include "bpe_kernels.h"
#include "ortx_common.h"
#include "bpe_kernels.h"
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"
#include <optional>
#include <limits>
using namespace ort_extensions;
const char kModel_Default[] = "PreTrained";
const char kModel_GPT2[] = "GPT2";
const char kModel_CodeGen[] = "CodeGen";
const char kModel_Roberta[] = "Roberta";
@ -43,40 +45,39 @@ std::string BpeModelConf::GetSpecialTokens() const {
}
// Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
bool IsUnicodeSpace(char32_t ch) {
switch (ch) {
case 0x0009:
case 0x000A:
case 0x000B:
case 0x000C:
case 0x000D:
case 0x001C:
case 0x001D:
case 0x001E:
case 0x001F:
case 0x0020:
case 0x0085:
case 0x00A0:
case 0x1680:
case 0x2000:
case 0x2001:
case 0x2002:
case 0x2003:
case 0x2004:
case 0x2005:
case 0x2006:
case 0x2007:
case 0x2008:
case 0x2009:
case 0x200A:
case 0x2028:
case 0x2029:
case 0x202F:
case 0x205F:
case 0x3000:
return true;
}
return false;
static bool IsUnicodeSpace(char32_t ch) {
const std::set<char32_t> unicode_spaces = {
0x0009, // CHARACTER TABULATION
0x000A, // LINE FEED (LF)
0x000B, // LINE TABULATION
0x000C, // FORM FEED (FF)
0x000D, // CARRIAGE RETURN (CR)
0x001C, // FILE SEPARATOR
0x001D, // GROUP SEPARATOR
0x001E, // RECORD SEPARATOR
0x001F, // UNIT SEPARATOR
0x0020, // SPACE
0x0085, // NEXT
0x00A0, // NO-BREAK SPACE
0x1680, // OGHAM SPACE MARK
0x2000, // EN QUAD
0x2001, // EM QUAD
0x2002, // EN SPACE
0x2003, // EM SPACE
0x2004, // THREE-PER-EM SPACE
0x2005, // FOUR-PER-EM SPACE
0x2006, // SIX-PER-EM SPACE
0x2007, // FIGURE SPACE
0x2008, // PUNCTUATION SPACE
0x2009, // THIN SPACE
0x200A, // HAIR SPACE
0x2028, // LINE SEPARATOR
0x2029, // PARAGRAPH SEPARATOR
0x202F, // NARROW NO-BREAK SPACE
0x205F, // MEDIUM MATHEMATICAL SPACE
0x3000, // IDEOGRAPHIC SPACE
};
return unicode_spaces.count(ch) > 0;
}
bool AllSpaceUstring(const ustring& str) {
@ -105,7 +106,7 @@ ustring RemoveConsecutiveSpaces(const ustring& input) {
KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
: bpe_conf_(conf) {
model_name_ = conf.name_;
model_name_ = conf.name_ == nullptr ? "" : conf.name_;
};
OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
@ -133,13 +134,13 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
model_name_ = model_name;
}
std::stringstream vocabu_stream(vocab);
std::stringstream vocab_stream(vocab);
std::stringstream merges_stream(merges);
bbpe_tokenizer_ = std::make_unique<BpeModel>();
auto status = bbpe_tokenizer_->Load(vocabu_stream,
auto status = bbpe_tokenizer_->Load(vocab_stream,
merges_stream,
bpe_conf_.unk_token_,
bpe_conf_.GetSpecialTokens().c_str(),
bpe_conf_.get().unk_token_,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
if (!status.IsOk()) {
return status.CreateOrtStatus();
@ -153,15 +154,14 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
}
// TODO: need to check if the special token ids are the same as the ones in HFTokenizer
unk_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.unk_token_);
if (bpe_conf_.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.bos_token_);
if (bpe_conf_.get().bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().bos_token_);
}
if (bpe_conf_.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.eos_token_);
if (bpe_conf_.get().eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().eos_token_);
}
if (bpe_conf_.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.pad_token_);
if (bpe_conf_.get().pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().pad_token_);
}
return {};
@ -202,10 +202,23 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
return res;
}
if (IsBosEosRequired(ModelName())) {
// Add BOS token to result
bool add_bos_token = false;
if (add_bos_token_.has_value()) {
add_bos_token = add_bos_token_.value();
} else if (IsBosEosRequired(ModelName())) {
add_bos_token = true;
}
if (add_bos_token) {
res.push_back(bos_token_id_);
}
bool add_eos_token = false;
if (add_eos_token_.has_value()) {
add_eos_token = add_eos_token_.value();
} else if (IsBosEosRequired(ModelName())) {
add_eos_token = true;
}
if (ModelName() == kModel_CLIP) {
// Convert to lowercase
std::transform(input.begin(), input.end(), input.begin(), [](char32_t c) { return static_cast<char32_t>(ToLower(c)); });
@ -231,7 +244,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
OffsetMappingType offset_mapping;
if (compute_offset_mapping) {
if (IsBosEosRequired(ModelName())) {
if (add_bos_token) {
// Add offset mapping for BOS token
offset_mapping.push_back(std::make_pair(0, 0));
}
@ -300,7 +313,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}
if (compute_offset_mapping) {
if (IsBosEosRequired(ModelName())) {
if (add_eos_token) {
// Add offset mapping for EOS token
offset_mapping.emplace_back(std::make_pair(0, 0));
}
@ -309,7 +322,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}
}
if (IsBosEosRequired(ModelName())) {
if (add_eos_token) {
// Add EOS token to result
res.push_back(eos_token_id_);
}
@ -529,3 +542,95 @@ static const auto kSpmConfiguration = BpeModelConf{
SpmTokenizer::SpmTokenizer()
: KernelBpeTokenizer(kSpmConfiguration) {}
JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs(voc_file);
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
}
const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
json_conf_.name_ = model_name_.c_str();
json_conf_.bos_token_ = config.bos_token_.c_str();
json_conf_.eos_token_ = config.eos_token_.c_str();
json_conf_.unk_token_ = config.unk_token_.c_str();
json_conf_.pad_token_ = config.pad_token_.c_str();
// re-bind the configuration object
bpe_conf_ = json_conf_;
// consider to use SAX parser for large json file
nlohmann::json tok_json;
ifs >> tok_json;
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}
bbpe_tokenizer_ = std::make_unique<BpeModel>();
auto status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
bpe::AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
added_token.lstrip_ = token.value("lstrip", false);
added_token.normalized_ = token.value("normalized", false);
added_token.rstrip_ = token.value("rstrip", false);
added_token.single_word_ = token.value("single_word", false);
added_token.special_ = token.value("special", false);
added_tokens_.emplace_back(added_token);
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}
}
if (!status.IsOk()) {
return status;
}
status = bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
if (!status.IsOk()) {
return status;
}
add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
// add_bos_token is default as false, we need to check post_processor json to see if it is true
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
std::string text = post_processor->dump();
if (text.find(config.bos_token_) != std::string::npos) {
add_bos_token_ = true;
}
if (text.find(config.eos_token_) != std::string::npos) {
add_eos_token_ = true;
}
}
}
return status;
}
OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
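
Similarly for the vocabulary file: a minimal, fabricated tokenizer.json shape that JsonFastTokenizer::Load above can walk, with a model node holding vocab and merges plus an optional added_tokens array. The tokens, merges, and helper names below are invented for illustration.

// Illustration only: the smallest tokenizer.json layout the loader above expects.
#include "nlohmann/json.hpp"

static const char* kSampleTokenizerJson = R"({
  "added_tokens": [
    { "id": 0, "content": "<unk>", "special": true },
    { "id": 1, "content": "<s>", "special": true },
    { "id": 2, "content": "</s>", "special": true }
  ],
  "model": {
    "unk_token": "<unk>",
    "vocab": { "<unk>": 0, "<s>": 1, "</s>": 2, "a": 3, "b": 4, "ab": 5 },
    "merges": [ "a b" ]
  }
})";

inline nlohmann::json ParseSampleTokenizerJson() {
  return nlohmann::json::parse(kSampleTokenizerJson);
}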

View file

@ -7,12 +7,15 @@
#include "status.h"
#include "ustring.h"
#include <list>
#include <string>
#include <vector>
#include <list>
#include <functional>
#include "bpe_types.h"
struct BpeModelConf {
const char* name_{"GPT2"}; // this name may be overridden by the tokenizer's attribute.
const char* name_{"GPT2"}; // this name may be overridden by the tokenizer's attribute.
const char* unk_token_{"<|endoftext|>"};
const char* bos_token_{"<|endoftext|>"};
const char* eos_token_{"<|endoftext|>"};
@ -21,18 +24,14 @@ struct BpeModelConf {
std::string GetSpecialTokens() const;
};
namespace ort_extensions {
class BpeModel;
}
struct KernelBpeTokenizer {
KernelBpeTokenizer(const BpeModelConf& conf);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
const std::string& ModelName() const { return model_name_; }
@ -48,25 +47,27 @@ struct KernelBpeTokenizer {
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;
private:
const BpeModelConf& bpe_conf_;
protected:
std::reference_wrapper<BpeModelConf const> bpe_conf_;
std::string model_name_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
int64_t padding_length_ = -1;
uint32_t unk_token_id_{};
uint32_t bos_token_id_{};
uint32_t eos_token_id_{};
uint32_t pad_token_id_{};
std::optional<bool> add_bos_token_;
std::optional<bool> add_eos_token_;
};
struct GPT2Tokenizer : KernelBpeTokenizer {
GPT2Tokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
@ -75,9 +76,9 @@ struct RobertaTokenizer : KernelBpeTokenizer {
RobertaTokenizer();
// required by LiteCustomOp which needs a explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
@ -86,9 +87,9 @@ struct CLIPTokenizer : KernelBpeTokenizer {
CLIPTokenizer();
// required by LiteCustomOp which needs a explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
@ -97,9 +98,27 @@ struct SpmTokenizer : KernelBpeTokenizer {
SpmTokenizer();
// required by LiteCustomOp which needs a explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
class JsonFastTokenizer : KernelBpeTokenizer {
public:
JsonFastTokenizer();
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
public:
auto GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
private:
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};

View file

@ -0,0 +1,270 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "bpe_kernels.h"
#include "bpe_decoder.hpp"
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"
namespace ort_extensions {
struct BPEDecoderState {
bool f_special_last{};
std::string incomplete_utf8_;
};
} // namespace ort_extensions
class BpeStreamingDecoder : public KernelBpeDecoder {
public:
BpeStreamingDecoder() = default;
~BpeStreamingDecoder() override = default;
using BPEDecoderState = ort_extensions::BPEDecoderState;
// shared the data between the encoder and decoder
OrtxStatus Load(
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
const JsonFastTokenizer& encoder) {
const auto& tok_config = *ptr_config;
bos_token_ = tok_config.bos_token_;
eos_token_ = tok_config.eos_token_;
unk_token_ = tok_config.unk_token_;
auto& a_toks = encoder.GetAddedTokens();
for (const auto& tok : a_toks) {
added_tokens_[tok.id_] = tok.content_;
if (tok.special_) {
all_special_ids_.insert(tok.id_);
}
}
auto& tok_model = encoder.GetEncoder();
CreateByteDecoder(tok_model);
arr_vocab_ = tok_model.BuildDecoder();
end_of_word_suffix_ = tok_model.GetEndOfWordSuffix();
// whitespace_token_ = tok_config.clean_up_tokenization_spaces_ ? 1 : 0;
skip_special_tokens_ = 1;
// en_normalization_ = 0;
add_dummy_prefix_ = tok_config.tokenizer_class_ == "LlamaTokenizer" ? 1 : 0;
eos_token_id_ = encoder.GetEncoder().GetTokenId(tok_config.eos_token_);
tok_config_ = ptr_config;
return {};
}
static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}
return result;
}
static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}
OrtxStatus Id2Token(extTokenId_t id,
std::string& token,
bool skip_special_tokens,
bool& f_special_last) const {
bool f_special = all_special_ids_.count(id) ? true : false;
if (skip_special_tokens && f_special) {
f_special_last = f_special;
return {};
}
if (added_tokens_.count(id)) {
const std::string ws = added_tokens_.at(id);
token = (std::string)ws;
} else if (static_cast<size_t>(id) < arr_vocab_.size()) {
const auto str = ustring(arr_vocab_[id]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0 && wchr <= 0xFF) {
token.push_back(gsl::narrow<unsigned char>(wchr));
} else {
unsigned char uchr = byte_decoder_.at(wchr);
token.push_back(uchr);
}
}
} else {
if (skip_special_tokens) {
f_special_last = f_special;
return {};
} else {
token = unk_token_;
}
}
// remove the end_of_word_suffix like </w> or </s> etc.
if (end_of_word_suffix_.size() > 0) {
if (token.size() >= end_of_word_suffix_.size() &&
token.substr(token.size() - end_of_word_suffix_.size()) == end_of_word_suffix_) {
token = token.substr(0, token.size() - end_of_word_suffix_.size());
token += ' ';
}
}
f_special_last = f_special;
return {};
}
OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
const char spm_underscore[] = "\xe2\x96\x81";
std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
bool f_special = false;
if (piece.empty() || all_special_ids_.count(id)) {
token = "";
f_special = true;
} else if (IsSpmByteWord(piece)) {
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
token = {static_cast<char>(strtol(buf, NULL, 16))};
} else {
token = ReplaceAll(piece, spm_underscore, " ");
}
if (!token.empty() && token[0] == ' ' && f_special_last && add_dummy_prefix_) {
token = token.substr(1);
}
f_special_last = f_special;
return {};
}
static bool IsSpmTokenizer(const std::string& tok_class) {
return tok_class == "GemmaTokenizer" || tok_class == "LlamaTokenizer";
}
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
auto bpe_state = *state;
std::unique_ptr<BPEDecoderState> bpe_state_ptr;
bool is_first = false;
if (bpe_state == nullptr) {
bpe_state_ptr = std::make_unique<BPEDecoderState>();
bpe_state = bpe_state_ptr.get();
is_first = true;
}
bool f_special = bpe_state->f_special_last; // [Spm]Id2Token needs the last state
bool f_special_last = bpe_state->f_special_last;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, token, f_special)
: Id2Token(id, token, true /* tok_config_.skip_special_tokens_ */, f_special);
if (status.IsOk()) {
if (bpe_state_ptr) {
*state = bpe_state_ptr.release();
}
if (tok_config_->clean_up_tokenization_spaces_) {
if (f_special && (is_first && !f_special_last)) {
token = std::string(" ") + token;
}
if (f_special && id != eos_token_id_) {
token.push_back(' ');
}
} // end case of whitespace_token_
if (!bpe_state->incomplete_utf8_.empty()) {
token = bpe_state->incomplete_utf8_ + token;
bpe_state->incomplete_utf8_.clear();
} else {
if (!token.empty() && ustring::UTF8Len(token.front()) > token.size()) {
bpe_state->incomplete_utf8_ = token;
token = "";
}
}
}
bpe_state->f_special_last = f_special;
return status;
}
OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
if (ids_dim.size() > 1) {
output_dim.resize(ids_dim.size() - 1);
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}
size_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
decoded_strings.reserve(string_batch);
for (auto n = string_batch; n > 0; n--) {
bool f_special_last = false;
std::string text;
for (size_t tok_idx = 0; tok_idx < seq_len; ++tok_idx) {
const auto id = ort_extensions::narrow<extTokenId_t>(*(p_ids + tok_idx));
std::string decoded_token;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, decoded_token, f_special_last)
: Id2Token(id, decoded_token, true, f_special_last);
if (!status.IsOk()) {
return status;
}
bool f_special = all_special_ids_.count(id) ? true : false;
if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}
text.append(decoded_token);
if (whitespace_token_ && f_special && tok_idx != seq_len - 1) {
text.push_back(' ');
}
}
if (tok_config_->tokenizer_class_.find("CLIP") == 0 && !text.empty() && text.back() == ' ') {
text.pop_back();
}
decoded_strings.emplace_back(std::move(text));
p_ids += seq_len;
}
output.SetStringOutput(decoded_strings, output_dim);
return {};
}
private:
void CreateByteDecoder(const ort_extensions::BpeModel& /* bpe_model */) {
char32_t index = 256;
for (char32_t i = 0; i < 256; ++i) {
/*
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
*/
if ((/* i >= 0 && */ i < 33) || (i >= 127 && i < 161) || (i == 173)) {
byte_decoder_[index++] = gsl::narrow<unsigned char>(i);
} else {
byte_decoder_[i] = gsl::narrow<unsigned char>(i);
}
}
}
private:
extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
};
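
The stateful Id2Token overload above threads a BPEDecoderState through successive calls so the last-special flag and any incomplete UTF-8 bytes survive from one id to the next. Below is a minimal sketch of an incremental decode loop; StreamDecode is an illustrative helper, the decoder is assumed to have been Load()-ed already, and the cache handling mirrors what TokenizerImpl does later in this change.

// Sketch only: decode ids one at a time with the stateful Id2Token overload.
#include <memory>
#include <string>
#include <vector>
#include "bpe_streaming.hpp"

std::string StreamDecode(const BpeStreamingDecoder& decoder,
                         const std::vector<extTokenId_t>& ids) {
  std::string text;
  std::unique_ptr<ort_extensions::BPEDecoderState> cache;  // created on the first call
  for (auto id : ids) {
    std::string piece;
    ort_extensions::BPEDecoderState* state = cache.get();
    if (!decoder.Id2Token(id, piece, &state).IsOk()) {
      break;
    }
    if (state != cache.get()) {
      cache.reset(state);  // take ownership of the state Id2Token allocated
    }
    text += piece;  // stays empty while a UTF-8 sequence is still incomplete
  }
  return text;
}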

View file

@ -19,10 +19,13 @@
#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"
#include "bpe_types.h"
namespace ort_extensions {
class BpeModel {
using json = nlohmann::json;
public:
BpeModel() = default;
@ -49,7 +52,7 @@ class BpeModel {
bool spm_converted) {
nlohmann::json tok_json;
vocab_stream >> tok_json;
vocab_map_ = std::move(tok_json.get<std::unordered_map<std::string, uint32_t>>());
tok_json.get_to(vocab_map_);
auto it = vocab_map_.find(unk_token);
if (it != vocab_map_.end()) {
@ -123,6 +126,75 @@ class BpeModel {
return {};
}
OrtxStatus Load(const json& bpe_model,
const char* /* special_tokens */,
bool spm_converted) {
const json& vocab_json = bpe_model["vocab"];
const json& merges_json = bpe_model["merges"];
vocab_json.get_to(vocab_map_);
auto it = bpe_model.find("unk_token");
if (it != bpe_model.end() && it->is_string()) {
auto ukt = it->get<std::string>();
auto it_word = vocab_map_.find(ukt);
if (it_word != vocab_map_.end()) {
unk_id_ = it_word->second;
}
}
it = bpe_model.find("end_of_word_suffix");
if (it != bpe_model.end() && it->is_string()) {
end_of_word_suffix_ = it->get<std::string>();
}
if (spm_converted) {
UpdateSpmByteToken(vocab_map_);
} else {
CreateByteEncoder();
}
uint32_t index = 0;
auto merge_item = merges_json.begin();
while (merge_item != merges_json.end()) {
std::string line = merge_item.value();
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
return {
kOrtxErrorCorruptData,
"Cannot know how to parse line: " + line,
};
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
merge_item++;
}
id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
continue; // safe purpose.
}
if (i > id2token_map_.size()) {
id2token_map_.resize(static_cast<size_t>(i) + 1);
}
id2token_map_[i] = t;
}
return {};
}
OrtxStatus LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
@ -149,6 +221,18 @@ class BpeModel {
return {};
}
OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}
return {};
}
std::vector<std::string> BuildDecoder() const {
return id2token_map_;
}
// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
// split by added tokens
@ -223,15 +307,19 @@ class BpeModel {
return byte_encoder_;
}
uint32_t GetTokenId(const std::string& key) {
uint32_t GetTokenId(const std::string& key) const {
auto it = vocab_map_.find(key);
if (it != end(vocab_map_)) {
if (it != vocab_map_.end()) {
return it->second;
} else {
return unk_id_;
}
}
const std::string& GetEndOfWordSuffix() const {
return end_of_word_suffix_;
}
private:
struct BpeNode {
uint32_t id;
@ -260,6 +348,7 @@ class BpeModel {
}
private:
std::string end_of_word_suffix_;
std::map<uint64_t, BpeNode> bpe_rank_;
uint32_t byte_encoder_[256] = {};

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
namespace ort_extensions {
class BpeModel;
namespace bpe {
struct AddedToken final {
uint32_t id_{};
std::string token_type_;
std::string content_;
bool lstrip_{};
bool normalized_{};
bool rstrip_{};
bool single_word_{};
bool special_{};
};
class TokenJsonConfig; // forward declaration
} // namespace bpe
} // namespace ort_extensions

18
shared/extensions_c.def Normal file
View file

@ -0,0 +1,18 @@
LIBRARY "ortextensions.dll"
EXPORTS
RegisterCustomOps @1
AddExternalCustomOp @2
GetActiveOrtAPIVersion @3
OrtxGetAPIVersion @4
OrtxGetLastErrorMessage @5
OrtxCreate @6
OrtxDispose @7
OrtxCreateTokenizer @8
OrtxTokenize @9
OrtxDetokenize @10
OrtxDetokenize1D @11
OrtxDetokenizeCached @12
OrtxStringArrayGetBatch @13
OrtxStringArrayGetItem @14
OrtxTokenId2DArrayGetBatch @15
OrtxTokenId2DArrayGetItem @16
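
The implementation behind these exports follows in the next file. Below is a minimal end-to-end sketch of the call sequence; RunTokenizerOnce, the input string, and the "data/llama2" directory are placeholders, and the header name follows the include used by tokenizer_impl.h.

// Sketch only: create a tokenizer from a local directory, tokenize one string,
// then detokenize the ids again through the exported C ABI.
// Error checks after the first call are elided for brevity.
#include <stdio.h>
#include "ortx_tokenizer.h"

int RunTokenizerOnce(void) {
  OrtxTokenizer* tok = NULL;
  if (OrtxCreateTokenizer(&tok, "data/llama2" /* placeholder path */) != kOrtxOK) {
    printf("create failed: %s\n", OrtxGetLastErrorMessage());
    return -1;
  }

  const char* input[] = {"I like walking my cute dog"};
  OrtxTokenId2DArray* ids_2d = NULL;
  OrtxTokenize(tok, input, 1, &ids_2d);

  const extTokenId_t* ids = NULL;
  size_t len = 0;
  OrtxTokenId2DArrayGetItem(ids_2d, 0, &ids, &len);

  OrtxStringArray* text = NULL;
  OrtxDetokenize1D(tok, ids, len, &text);
  const char* round_trip = "";
  OrtxStringArrayGetItem(text, 0, &round_trip);
  printf("%zu tokens -> %s\n", len, round_trip);

  OrtxDisposeOnly((OrtxObject*)text);
  OrtxDisposeOnly((OrtxObject*)ids_2d);
  OrtxDisposeOnly((OrtxObject*)tok);
  return 0;
}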

View file

@ -0,0 +1,385 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstdarg>
#include <filesystem>
#include <algorithm>
#include "tokenizer_impl.h"
namespace ort_extensions {
class TokenId2DArray : public OrtxObjectImpl {
public:
TokenId2DArray() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenId2DArray) {}
~TokenId2DArray() override = default;
void SetTokenIds(std::vector<std::vector<extTokenId_t>>&& token_ids) {
token_ids_ = token_ids;
}
[[nodiscard]] const std::vector<std::vector<extTokenId_t>>& token_ids() const {
return token_ids_;
}
private:
std::vector<std::vector<extTokenId_t>> token_ids_;
};
class StringArray : public OrtxObjectImpl {
public:
StringArray() : OrtxObjectImpl(extObjectKind_t::kOrtxKindStringArray) {}
~StringArray() override = default;
void SetStrings(std::vector<std::string>&& strings) {
strings_ = strings;
}
[[nodiscard]] const std::vector<std::string>& strings() const {
return strings_;
}
private:
std::vector<std::string> strings_;
};
class DetokenizerCache : public OrtxObjectImpl {
public:
DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
~DetokenizerCache() override = default;
std::unique_ptr<BPEDecoderState> decoder_state_{};
std::string last_text_{}; // last detokenized text
};
} // namespace ort_extensions
using namespace ort_extensions;
thread_local std::string last_error_message;
OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const {
if (ext_kind_ != static_cast<int>(kind)) {
return {extError_t::kOrtxErrorInvalidArgument,
"Object is not an instance of the requested type"};
}
return {};
}
struct ReturnableStatus {
ReturnableStatus() = default;
ReturnableStatus(OrtxStatus&& status) : status_(status) {}
~ReturnableStatus() {
if (!status_.IsOk()) {
last_error_message = status_.Message();
}
}
ReturnableStatus& operator=(OrtxStatus&& status) {
status_ = status;
return *this;
}
bool IsOk() const { return status_.IsOk(); }
extError_t Code() const { return status_.Code(); }
private:
OrtxStatus status_;
};
int ORTX_API_CALL OrtxGetAPIVersion() {
return API_VERSION;
}
const char* OrtxGetLastErrorMessage() {
return last_error_message.c_str();
}
extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, ...) {
if (object == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
if (kind == extObjectKind_t::kOrtxKindUnknown) {
return kOrtxErrorInvalidArgument;
}
va_list args;
va_start(args, object);
if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) {
*object = std::make_unique<DetokenizerCache>().release();
} else if (kind == extObjectKind_t::kOrtxKindTokenizer) {
return OrtxCreateTokenizer(static_cast<OrtxTokenizer**>(object), va_arg(args, const char*));
}
va_end(args);
return extError_t();
}
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer,
const char* tokenizer_path) {
// test if the tokenizer_path is a valid directory
if (tokenizer_path == nullptr) {
last_error_message = "The tokenizer data directory is null";
return kOrtxErrorInvalidArgument;
}
if (!std::filesystem::is_directory(tokenizer_path)) {
last_error_message = std::string("Cannot find the directory of ") + tokenizer_path;
return kOrtxErrorInvalidArgument;
}
ReturnableStatus status;
// auto ptr = ort_extensions::CreateTokenizer(tokenizer_path, "", &status);
auto ptr = std::make_unique<ort_extensions::TokenizerImpl>();
status = ptr->Load(tokenizer_path);
if (status.IsOk()) {
*tokenizer = static_cast<OrtxTokenizer*>(ptr.release());
return extError_t();
}
return status.Code();
}
template <typename T>
void Dispose(T* object) {
auto token_ptr = static_cast<T*>(object);
std::unique_ptr<T> ptr(token_ptr);
ptr.reset();
}
extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) {
if (object == nullptr || *object == nullptr) {
return kOrtxErrorInvalidArgument;
}
auto Ortx_object = static_cast<OrtxObjectImpl*>(*object);
if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindUnknown) {
return kOrtxErrorInvalidArgument;
}
if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
Dispose(static_cast<ort_extensions::StringArray*>(*object));
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) {
Dispose(static_cast<ort_extensions::TokenId2DArray*>(*object));
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) {
Dispose(static_cast<ort_extensions::DetokenizerCache*>(*object));
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenizer) {
Dispose(static_cast<ort_extensions::TokenizerImpl*>(*object));
}
*object = nullptr;
return extError_t();
}
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
return OrtxDispose(&object);
}
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
const char* input[], size_t batch_size, OrtxTokenId2DArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status =
token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
if (!status.IsOk()) {
return status.Code();
}
std::vector<std::vector<extTokenId_t>> t_ids;
std::vector<std::string_view> input_view;
std::transform(input, input + batch_size, std::back_inserter(input_view),
[](const char* str) { return std::string_view(str); });
status = token_ptr->Tokenize(input_view, t_ids);
if (!status.IsOk()) {
return status.Code();
}
auto result = std::make_unique<ort_extensions::TokenId2DArray>().release();
result->SetTokenIds(std::move(t_ids));
*output = static_cast<OrtxTokenId2DArray*>(result);
return extError_t();
}
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
const OrtxTokenId2DArray* input, OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
if (!status.IsOk()) {
return status.Code();
}
auto input_2d = static_cast<const TokenId2DArray*>(input);
status = input_2d->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray);
if (!status.IsOk()) {
return status.Code();
}
std::vector<span<extTokenId_t const>> t_ids;
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(),
std::back_inserter(t_ids),
[](const std::vector<extTokenId_t>& vec) {
return span<extTokenId_t const>(vec.data(), vec.size());
});
std::vector<std::string> output_text;
status = token_ptr->Detokenize(t_ids, output_text);
if (!status.IsOk()) {
return status.Code();
}
auto result = std::make_unique<ort_extensions::StringArray>().release();
result->SetStrings(std::move(output_text));
*output = static_cast<OrtxStringArray*>(result);
return extError_t();
;
}
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer,
const extTokenId_t* input,
size_t len,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
if (!status.IsOk()) {
return status.Code();
}
std::vector<span<extTokenId_t const>> t_ids = {{input, len}};
std::vector<std::string> output_text;
status = token_ptr->Detokenize(t_ids, output_text);
if (!status.IsOk()) {
return status.Code();
}
auto result = std::make_unique<ort_extensions::StringArray>().release();
result->SetStrings(std::move(output_text));
*output = static_cast<OrtxStringArray*>(result);
return extError_t();
}
extError_t ORTX_API_CALL OrtxStringArrayGetBatch(const OrtxStringArray* string_array, size_t* length) {
if (string_array == nullptr || length == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const StringArray*>(string_array);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindStringArray));
if (!status.IsOk()) {
return status.Code();
}
*length = token_ptr->strings().size();
return extError_t();
}
extError_t ORTX_API_CALL OrtxStringArrayGetItem(const OrtxStringArray* string_array, size_t index, const char** item) {
if (string_array == nullptr || item == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const StringArray*>(string_array);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindStringArray));
if (!status.IsOk()) {
return status.Code();
}
if (index >= token_ptr->strings().size()) {
last_error_message = "the index is out of range";
return kOrtxErrorInvalidArgument;
}
*item = token_ptr->strings()[index].c_str();
return extError_t();
}
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* token_id_2d_array, size_t* length) {
if (token_id_2d_array == nullptr || length == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_2d_ptr = static_cast<const TokenId2DArray*>(token_id_2d_array);
ReturnableStatus status(token_2d_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray));
if (!status.IsOk()) {
return status.Code();
}
*length = token_2d_ptr->token_ids().size();
return extError_t();
}
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array,
size_t index, const extTokenId_t** item, size_t* length) {
if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const TokenId2DArray*>(token_id_2d_array);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray));
if (!status.IsOk()) {
return status.Code();
}
if (index >= token_ptr->token_ids().size()) {
last_error_message = "the index is out of range";
return kOrtxErrorInvalidArgument;
}
*item = token_ptr->token_ids()[index].data();
*length = token_ptr->token_ids()[index].size();
return extError_t();
}
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer,
OrtxDetokenizerCache* cache,
extTokenId_t next_id, const char** text_out) {
if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) {
last_error_message = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
if (!status.IsOk()) {
return status.Code();
}
auto cache_ptr = static_cast<DetokenizerCache*>(cache);
status = ReturnableStatus(cache_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindDetokenizerCache));
if (!status.IsOk()) {
return status.Code();
}
cache_ptr->last_text_.clear();
status = ReturnableStatus(token_ptr->Id2Token(next_id, cache_ptr->last_text_, cache_ptr->decoder_state_));
if (status.IsOk()) {
*text_out = cache_ptr->last_text_.c_str();
}
return status.Code();
}
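
OrtxDetokenizeCached above keeps a DetokenizerCache alive between calls so ids produced one at a time (for example during generation) can be turned into text incrementally. A short sketch follows, under the same assumptions as the earlier C ABI example; StreamPrint is an illustrative helper, and the enum constants are assumed to be plain C enums, as a C ABI header implies.

// Sketch only: incremental detokenization through the cached C API.
// `tok` is an already-created OrtxTokenizer*; `ids`/`len` is a generated sequence.
#include <stdio.h>
#include "ortx_tokenizer.h"

void StreamPrint(const OrtxTokenizer* tok, const extTokenId_t* ids, size_t len) {
  OrtxObject* cache = NULL;
  OrtxCreate(kOrtxKindDetokenizerCache, &cache);
  for (size_t i = 0; i < len; ++i) {
    const char* chunk = NULL;
    if (OrtxDetokenizeCached(tok, (OrtxDetokenizerCache*)cache, ids[i], &chunk) == kOrtxOK) {
      fputs(chunk, stdout);  // may be empty while UTF-8 bytes are still pending
    }
  }
  OrtxDisposeOnly(cache);
}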

View file

View file

@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "bpe_kernels.h"
#include "bpe_tokenizer.hpp"
#include "bpe_decoder.hpp"
#include "tokenizer_impl.h"
using namespace ort_extensions;
class SimpleAllocator : public ortc::IAllocator {
public:
void* Alloc(size_t size) override {
return malloc(size);
}
void Free(void* p) override {
if (p) {
free(p);
}
}
};
static SimpleAllocator g_allocator;
TokenizerImpl::TokenizerImpl() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer){};
TokenizerImpl::~TokenizerImpl(){};
OrtxStatus TokenizerImpl::Load(const std::string& dir) {
tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
auto status = tok_config_->Load(dir);
if (!status.IsOk()) {
return status;
}
tokenizer_ = std::make_unique<JsonFastTokenizer>();
// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);
if (status.IsOk()) {
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
status = detokenizer_->Load(tok_config_, *tokenizer_);
}
return status;
}
OrtxStatus TokenizerImpl::BatchEncode(
const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
for (const auto& s : input) {
ortc::Tensor<int64_t> ts_output(&g_allocator);
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
auto status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
if (!status.IsOk()) {
return status;
}
std::vector<extTokenId_t> ids(ts_output.NumberOfElement());
std::transform(ts_output.Data(), ts_output.Data() + ts_output.NumberOfElement(), ids.begin(),
[](int64_t v) { return static_cast<extTokenId_t>(v); });
t_ids.emplace_back(std::move(ids));
}
return {};
}
OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids,
std::vector<std::string>& t_text) const {
for (const auto& s : t_ids) {
std::vector<int64_t> ids(s.size());
std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
ortc::Tensor<std::string> ts_output;
OrtxStatus status = detokenizer_->Compute(ts_input, ts_output);
if (!status.IsOk()) {
return status;
}
t_text.emplace_back(ts_output.AsScalar());
}
return {};
}
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
return detokenizer_->Id2Token(id, token, state);
}


@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ortx_tokenizer.h"
#include "bpe_kernels.h"
#include "bpe_json.hpp"
#include "bpe_streaming.hpp"
namespace ort_extensions {
class OrtxObjectImpl : public OrtxObject {
public:
explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() {
ext_kind_ = static_cast<int>(kind);
}
virtual ~OrtxObjectImpl() = default;
[[nodiscard]] OrtxStatus IsInstanceOf(extObjectKind_t kind) const;
[[nodiscard]] extObjectKind_t ortx_kind() const {
if (ext_kind_ < static_cast<int>(extObjectKind_t::kOrtxKindBegin) ||
ext_kind_ >= static_cast<int>(extObjectKind_t::kOrtxKindEnd)) {
return extObjectKind_t::kOrtxKindUnknown;
}
return static_cast<extObjectKind_t>(ext_kind_);
}
};
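// A lightweight non-owning view over contiguous token ids; a small stand-in for std::span.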
template <typename T>
class span {
public:
using value_type = std::remove_cv_t<T>;
span(T* d, size_t s) : data_(d), size_(s) {}
span(std::vector<value_type>& v) {
data_ = v.data();
size_ = v.size();
}
T* data() const { return data_; }
[[nodiscard]] size_t size() const { return size_; }
T* begin() const { return data_; }
T* end() const { return data_ + size_; }
private:
T* data_;
size_t size_;
};
class TokenizerImpl : public OrtxObjectImpl {
public:
TokenizerImpl();
virtual ~TokenizerImpl();
public:
OrtxStatus Load(const std::string& dir);
OrtxStatus Tokenize(const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
return BatchEncode(input, t_ids);
}
OrtxStatus Detokenize(const std::vector<span<extTokenId_t const>>& t_ids,
std::vector<std::string>& t_text) const {
return BatchDecode(t_ids, t_text);
}
OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
BPEDecoderState* state_ptr = cache.get();
OrtxStatus status = Id2Token(id, token, &state_ptr);
if (status.IsOk()) {
if (state_ptr != cache.get()) {
cache.reset(state_ptr);
}
}
return status;
}
OrtxStatus BatchEncode(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const;
OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;
private:
std::string tokenizer_dir_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<JsonFastTokenizer> tokenizer_;
std::unique_ptr<BpeStreamingDecoder> detokenizer_;
};
} // namespace ort_extensions
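As a rough usage illustration of the class above (mirroring the unit tests later in this change; the function name and tokenizer directory are placeholders):

#include <memory>
#include <string_view>
#include <vector>
#include "shared/lib/tokenizer_impl.h"

// Sketch only: load a tokenizer, encode a batch, then decode it back.
OrtxStatus RoundTripExample() {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/llama2");  // placeholder tokenizer directory
  if (!status.IsOk()) return status;

  std::vector<std::string_view> input = {"This is a test"};
  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  if (!status.IsOk()) return status;

  std::vector<std::string> text;
  std::vector<ort_extensions::span<extTokenId_t const>> spans = {token_ids[0]};
  return tokenizer->Detokenize(spans, text);  // on success, text[0] round-trips the input
}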

test/data/clip/tokenizer.json (new file, 98393 lines added; diff not shown because of its size)


@ -0,0 +1,34 @@
{
"unk_token": {
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": true,
"__type": "AddedToken"
},
"bos_token": {
"content": "<|startoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": true,
"__type": "AddedToken"
},
"eos_token": {
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": true,
"__type": "AddedToken"
},
"pad_token": "<|endoftext|>",
"add_prefix_space": false,
"errors": "replace",
"do_lower_case": true,
"name_or_path": "openai/clip-vit-base-patch32",
"model_max_length": 77,
"special_tokens_map_file": "/home/suraj/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
"tokenizer_class": "CLIPTokenizer"
}


@ -0,0 +1,46 @@
{
"additional_special_tokens": [
{
"content": "<start_of_turn>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<end_of_turn>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
],
"bos_token": {
"content": "<bos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<eos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

test/data/gemma/tokenizer.json (new file, 836758 lines added; diff not shown because of its size)


@ -0,0 +1,70 @@
{
"add_bos_token": true,
"add_eos_token": false,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<eos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<bos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"3": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"106": {
"content": "<start_of_turn>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"107": {
"content": "<end_of_turn>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<start_of_turn>",
"<end_of_turn>"
],
"bos_token": "<bos>",
"chat_template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"clean_up_tokenization_spaces": false,
"eos_token": "<eos>",
"legacy": null,
"model_max_length": 1000000000000000019884624838656,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "GemmaTokenizer",
"unk_token": "<unk>",
"use_default_system_prompt": false
}


@ -0,0 +1,23 @@
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

(A large file diff is not shown because of its size.)


@ -0,0 +1,33 @@
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

test/data/phi-2/tokenizer.json (new file, 100647 lines added; diff not shown because of its size)


@ -0,0 +1,323 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"50256": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"50257": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50258": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50259": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50260": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50261": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50262": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50263": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50264": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50265": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50266": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50267": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50268": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50269": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50270": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50271": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50272": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50273": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50274": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50275": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50276": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50277": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50278": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50279": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50280": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50281": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50282": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50283": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50284": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50285": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50286": {
"content": " ",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50287": {
"content": "\t\t\t\t\t\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50288": {
"content": "\t\t\t\t\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50289": {
"content": "\t\t\t\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50290": {
"content": "\t\t\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50291": {
"content": "\t\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50292": {
"content": "\t\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50293": {
"content": "\t\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
},
"50294": {
"content": "\t\t",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": false
}
},
"bos_token": "<|endoftext|>",
"clean_up_tokenization_spaces": true,
"eos_token": "<|endoftext|>",
"model_max_length": 2048,
"tokenizer_class": "CodeGenTokenizer",
"unk_token": "<|endoftext|>"
}

(Two more large file diffs are not shown because of their size.)


@ -12,7 +12,11 @@ void convert_test(const char* const_str) {
std::string string(const_str);
const std::string const_string(const_str);
#if _WIN32
auto str = std::shared_ptr<char>(_strdup(const_str));
#else
auto str = std::shared_ptr<char>(strdup(const_str));
#endif
ustring char_construct(str.get());
EXPECT_EQ(const_string, std::string(char_construct));


@ -89,7 +89,7 @@ class TestMathOpString(unittest.TestCase):
inv_mat = np.linalg.inv(mat)
ort_inv = OrtPyFunction.from_customop('Inverse')
act_mat = ort_inv(mat)
self.assertTrue(np.allclose(inv_mat, act_mat, rtol=1.e-3))
np.testing.assert_allclose(inv_mat, act_mat, rtol=1.e-2, atol=1.e-3)
if __name__ == "__main__":


@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include <locale>
#include "gtest/gtest.h"
#include "ocos.h"
#include "ortx_tokenizer.h"
#include "bpe_kernels.h"
TEST(bbpe_tokenizer, test_encoder) {
EXPECT_EQ(0, ORT_OK);
}


@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/* C-only file, to verify the header file C compatibility */
#include <stdio.h>
#include <string.h>
#include "c_only_test.h"
#if defined(_WIN32)
#define strdup _strdup
#endif
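/* Round-trip helper: encode the text to ids, decode the ids back, and return a malloc'd copy that the caller must free(). */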
extError_t tokenize_text(OrtxTokenizer* tokenizer, const char* text, char** decoded_text) {
OrtxTokenId2DArray* tok_2d_output = NULL;
const char* tok_input[] = {text};
extError_t err = OrtxTokenize(tokenizer, tok_input, 1, &tok_2d_output);
if (err != kOrtxOK) {
return err;
}
size_t length = 0;
const extTokenId_t* token_ids = NULL;
OrtxTokenId2DArrayGetItem(tok_2d_output, 0, &token_ids, &length);
OrtxStringArray* detok_output = NULL;
err = OrtxDetokenize1D(tokenizer, token_ids, length, &detok_output);
if (err != kOrtxOK) {
ORTX_DISPOSE(tok_2d_output);
return err;
}
const char* decoded_str = NULL;
OrtxStringArrayGetItem(detok_output, 0, &decoded_str);
*decoded_text = strdup(decoded_str);
ORTX_DISPOSE(tok_2d_output);
ORTX_DISPOSE(detok_output);
return kOrtxOK;
}


@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ortx_tokenizer.h"
#ifdef __cplusplus
extern "C"
#endif // __cplusplus
extError_t tokenize_text(OrtxTokenizer* tokenizer, const char* text, char** decoded_text);


@ -0,0 +1,271 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include <locale>
#include "gtest/gtest.h"
#include "ocos.h"
#include "c_only_test.h"
#include "shared/lib/tokenizer_impl.h"
static void DumpTokenIds(const std::vector<std::vector<extTokenId_t>>& token_ids) {
#ifdef _DEBUG
for (const auto& tokens : token_ids) {
for (const auto& token : tokens) {
std::cout << token << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
#endif
}
TEST(CApiTest, ApiTest) {
int ver = OrtxGetAPIVersion();
EXPECT_GT(ver, 0);
OrtxTokenizer* tokenizer = NULL;
extError_t err = OrtxCreateTokenizer(&tokenizer, "data/tiktoken");
EXPECT_EQ(err, kOrtxOK);
const char* input = "This is a test";
char* decoded_text = NULL;
err = tokenize_text(tokenizer, input, &decoded_text);
EXPECT_EQ(err, kOrtxOK);
EXPECT_STREQ(decoded_text, input);
free(decoded_text);
}
TEST(CApiTest, StreamApiTest) {
OrtxTokenizer* tokenizer = NULL;
extError_t err = OrtxCreate(kOrtxKindTokenizer, &tokenizer, "data/llama2");
EXPECT_EQ(err, kOrtxOK);
OrtxDetokenizerCache* detok_cache = NULL;
err = OrtxCreate(kOrtxKindDetokenizerCache, &detok_cache);
EXPECT_EQ(err, kOrtxOK);
extTokenId_t token_ids[] = {1, 910, 338, 263, 1243, 322, 278, 1473, 697, 29889, 29871, 35};
for (size_t i = 0; i < sizeof(token_ids) / sizeof(token_ids[0]); i++) {
const char* token = NULL;
err = OrtxDetokenizeCached(tokenizer, detok_cache, token_ids[i], &token);
#ifdef _DEBUG
std::cout << token;
#endif
EXPECT_EQ(err, kOrtxOK);
}
#ifdef _DEBUG
std::cout << std::endl;
#endif
OrtxDisposeOnly(detok_cache);
OrtxDispose(&tokenizer);
}
TEST(OrtxTokenizerTest, ClipTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/clip");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
std::vector<std::string_view> input = {"this is a test", "the second one"};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(token_ids.size(), 2);
EXPECT_EQ(token_ids[0].size(), 6);
EXPECT_EQ(token_ids[1].size(), 5);
std::vector<std::string> out_text;
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
status = tokenizer->Detokenize(token_ids_span, out_text);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerTest, TicTokenTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/tiktoken");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
std::vector<extTokenId_t> EXPECTED_IDS_0 = {128000, 2028, 374, 264, 1296};
std::vector<std::string_view> input = {"This is a test", "the second one"};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(token_ids.size(), 2);
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
EXPECT_EQ(token_ids[1].size(), 4);
std::vector<std::string> out_text;
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
status = tokenizer->Detokenize(token_ids_span, out_text);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerTest, GemmaTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/gemma");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
std::vector<std::string_view> input = {
"I like walking my cute dog\n and\x17 then",
"生活的真谛是",
"\t\t\t\t \n\n61",
"Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2, 235285, 1154, 10350, 970, 9786, 5929, 108, 578, 240, 1492};
std::vector<extTokenId_t> EXPECTED_IDS_1 = {2, 122182, 235710, 245467, 235427};
std::vector<extTokenId_t> EXPECTED_IDS_2 = {2, 255971, 235248, 109, 235318, 235274};
std::vector<extTokenId_t> EXPECTED_IDS_3 = {2, 6750, 1, 235265, 235248, 255969, 235248, 109, 4747, 139, 235335, 139,
216311, 241316, 139, 239880, 235341, 144, 235269, 235248, 235274, 235284,
235304, 235310, 235248, 235274, 235308, 235248, 235308, 235269, 235318, 235274};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(token_ids.size(), input.size());
DumpTokenIds(token_ids);
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
EXPECT_EQ(token_ids[1], EXPECTED_IDS_1);
EXPECT_EQ(token_ids[2], EXPECTED_IDS_2);
EXPECT_EQ(token_ids[3], EXPECTED_IDS_3);
std::vector<std::string> out_text;
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {
EXPECTED_IDS_0, EXPECTED_IDS_1, EXPECTED_IDS_2, EXPECTED_IDS_3};
status = tokenizer->Detokenize(token_ids_span, out_text);
EXPECT_TRUE(status.IsOk());
// std::cout << out_text[0] << std::endl;
// std::cout << out_text[1] << std::endl;
// std::cout << out_text[2] << std::endl;
EXPECT_EQ(out_text[0], input[0]);
EXPECT_EQ(out_text[1], input[1]);
}
static const char* kPromptText = R"(```python
def print_prime(n):
"""
Print all primes between 1 and n
"""
primes = []
for num in range(2, n+1):
is_prime = True
for i in range(2, int(math.sqrt(num))+1):
if num % i == 0:
is_prime = False
break
if is_prime:
primes.append(num)
print(primes)''')";
TEST(OrtxTokenizerTest, CodeGenTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/phi-2");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
const char* prompt_text = kPromptText;
std::vector<std::string_view> input = {prompt_text};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(token_ids.size(), 1);
std::vector<std::string> out_text;
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0]};
status = tokenizer->Detokenize(token_ids_span, out_text);
EXPECT_TRUE(status.IsOk());
// std::cout << out_text[0] << std::endl;
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/phi-2");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
const char* prompt_text = kPromptText;
std::vector<std::string_view> input = {prompt_text};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(token_ids.size(), 1);
std::string text;
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
// token_ids[0].insert(token_ids[0].begin() + 2, 607); // <0x20>
token_ids[0] = {921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
for (const auto& token_id : token_ids[0]) {
std::string token;
status = tokenizer->Id2Token(token_id, token, decoder_cache);
EXPECT_TRUE(status.IsOk());
// std::cout << token;
text.append(token);
}
// EXPECT_EQ(text, input[0]);
}
TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
// test the llama2 tokenizer with BPE class, instead of sentencepiece wrapper.
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/llama2");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_TRUE(tokenizer != nullptr);
std::vector<std::string_view> input = {"This is a test and the second one. "};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
// Add an extra byte token for decoding tests
token_ids[0].push_back(35); // <0x20>
DumpTokenIds(token_ids);
std::string text;
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
// std::cout << "\"";
for (const auto& token_id : token_ids[0]) {
std::string token;
auto status = tokenizer->Id2Token(token_id, token, decoder_cache);
EXPECT_TRUE(status.IsOk());
// std::cout << token;
text.append(token);
}
// std::cout << "\"" << std::endl;
EXPECT_EQ(std::string(text), std::string(input[0]) + " "); // from the extra byte token */
}


@ -19,7 +19,7 @@ void FixCurrentDir(const std::string& init_path = "") {
auto cur_dir = std::filesystem::current_path();
if (!init_path.empty()) {
std::filesystem::path init_dir = init_path;
cur_dir = init_dir.parent_path();
cur_dir = std::filesystem::absolute(init_dir).parent_path();
}
do {