Add TrieTokenizer for RWKV-like LLM models (#509)

* Add TrieTokenizer for RWKV-like LLM models
* Add more tests
* Fix the Windows build
* Download the vocab file instead of checking it in
* A small bug fix
Parent: c8bb9e8abd
Commit: 978ada6d60
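Before the per-file diffs, here is a minimal usage sketch of the two new custom ops, assembled from the Python test added at the end of this commit (the vocab URL and the expected token ids come from that test); treat it as an illustration of the feature rather than additional API surface.

import requests
from onnxruntime_extensions import OrtPyFunction

# RWKV "world" vocab used by the test at the end of this commit
url = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
vocab_data = requests.get(url).content

tokenize = OrtPyFunction.from_customop("TrieTokenizer", vocab=vocab_data)
detokenize = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data)

ids = tokenize(["I love you"])      # -> [[74, 31337, 22799]] per the test below
print(list(detokenize(ids)))        # -> ["I love you"]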
@@ -48,6 +48,7 @@ option(OCOS_ENABLE_CPP_EXCEPTIONS "Enable C++ Exception" ON)
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
option(OCOS_ENABLE_RE2_REGEX "Enable StringRegexReplace and StringRegexSplit" ON)
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
option(OCOS_ENABLE_TRIE_TOKENIZER "Enable the TrieTokenizer building" ON)
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
option(OCOS_ENABLE_WORDPIECE_TOKENIZER "Enable the WordpieceTokenizer building" ON)
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)

@@ -72,6 +73,7 @@ function(disable_all_operators)
  set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_TRIE_TOKENIZER OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "" FORCE)
  set(OCOS_ENABLE_BLINGFIRE OFF CACHE INTERNAL "" FORCE)

@@ -346,6 +348,13 @@ if(OCOS_ENABLE_GPT2_TOKENIZER)
  list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif()

if(OCOS_ENABLE_TRIE_TOKENIZER)
  # Trie Tokenizer
  set(_HAS_TOKENIZER ON)
  file(GLOB tok_TARGET_SRC "operators/tokenizer/trie_tokenizer.hpp" "operators/tokenizer/unescape.h")
  list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif()

if(OCOS_ENABLE_SPM_TOKENIZER)
  # SentencePiece
  set(_HAS_TOKENIZER ON)

@@ -512,6 +521,10 @@ if(OCOS_ENABLE_GPT2_TOKENIZER)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
endif()

if(OCOS_ENABLE_TRIE_TOKENIZER)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TRIE_TOKENIZER)
endif()

if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER)
endif()

@@ -797,3 +810,30 @@ if(OCOS_ENABLE_CTEST)
    add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
  endif()
endif()

if(OCOS_ENABLE_AZURE)
  add_dependencies(ocos_operators triton)
  target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
  target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)

  if (ocos_target_platform STREQUAL "AMD64")
    set(vcpkg_target_platform "x86")
  else()
    set(vcpkg_target_platform ${ocos_target_platform})
  endif()

  if (WIN32)
    target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
    target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
  else()
    find_package(ZLIB REQUIRED)
    find_package(OpenSSL REQUIRED)
    target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
  endif() #if (WIN32)

endif()
@@ -29,11 +29,16 @@ target_compile_definitions(extensions_pydll PRIVATE
target_link_libraries(extensions_pydll PRIVATE Python3::Module ocos_operators)

if(NOT "${OCOS_EXTENTION_NAME}" STREQUAL "")
if(OCOS_PYTHON_MODULE_PATH)
  get_filename_component(OCOS_PYTHON_MODULE_NAME ${OCOS_PYTHON_MODULE_PATH} NAME)
  if(NOT WIN32)
    set_target_properties(extensions_pydll PROPERTIES
      LIBRARY_OUTPUT_NAME ${OCOS_EXTENTION_NAME}
      LIBRARY_OUTPUT_NAME ${OCOS_PYTHON_MODULE_NAME}
      PREFIX ""
      SUFFIX "")
  endif()

  add_custom_command(TARGET extensions_pydll POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:extensions_pydll> ${OCOS_PYTHON_MODULE_PATH}
    COMMENT "Copying $<TARGET_FILE:extensions_pydll> to ${OCOS_PYTHON_MODULE_PATH}")
endif()
@@ -350,6 +350,26 @@ class SentencepieceDecoder(CustomOp):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]


class TrieTokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]


class TrieDetokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]


class Inverse(CustomOp):

    @classmethod
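The two schema classes above only declare the op contracts: strings of shape ['N'] in and int64 ids of shape ['N', None] out, with the reverse for the detokenizer. As a hedged illustration of what that contract corresponds to at the graph level, the sketch below hand-builds a single-node model; the 'ai.onnx.contrib' domain and the opset versions are assumptions about how onnxruntime-extensions registers custom ops, not something stated in this diff.

import onnx
import requests
from onnx import helper, TensorProto

DOMAIN = "ai.onnx.contrib"   # assumption: the custom-op domain used by onnxruntime-extensions
vocab_url = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"

node = helper.make_node(
    "TrieTokenizer", ["str"], ["ids"], domain=DOMAIN,
    vocab=requests.get(vocab_url).content)   # the whole vocab file travels as a node attribute
graph = helper.make_graph(
    [node], "trie_tokenize",
    [helper.make_tensor_value_info("str", TensorProto.STRING, ["N"])],
    [helper.make_tensor_value_info("ids", TensorProto.INT64, ["N", None])])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid(DOMAIN, 1)])
onnx.save(model, "trie_tokenizer.onnx")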
@@ -27,42 +27,44 @@
#include "bert_tokenizer_decoder.hpp"
#endif

const std::vector<const OrtCustomOp*>& TokenizerLoader() {
#ifdef ENABLE_TRIE_TOKENIZER
#include "trie_tokenizer.hpp"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
  static OrtOpLoader op_loader(
      []() { return nullptr; }
#ifdef ENABLE_GPT2_TOKENIZER
      ,
      CustomCpuStruct("GPT2Tokenizer", KernelBpeTokenizer),
      CustomCpuStruct("CLIPTokenizer", KernelClipBpeTokenizer),
      CustomCpuStruct("RobertaTokenizer", KernelRobertaBpeTokenizer),
      CustomCpuStruct("BpeDecoder", KernelBpeDecoder)
      CustomCpuStruct("BpeDecoder", KernelBpeDecoder),
#endif

#ifdef ENABLE_SPM_TOKENIZER
      ,
      CustomCpuStruct("SentencepieceTokenizer", KernelSentencepieceTokenizer),
      CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder)
      CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder),
#endif

#ifdef ENABLE_TRIE_TOKENIZER
      CustomCpuStruct("TrieTokenizer", KernelTrieTokenizer),
      CustomCpuStruct("TrieDetokenizer", KernelTrieDetokenizer),
#endif

#ifdef ENABLE_WORDPIECE_TOKENIZER
      ,
      CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer)
      CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer),
#endif

#ifdef ENABLE_BERT_TOKENIZER
      ,
      CustomCpuStruct("BasicTokenizer", KernelBasicTokenizer),
      CustomCpuStruct("BertTokenizer", KernelBertTokenizer),
      CustomCpuStruct("BertTokenizerDecoder", KernelBertTokenizerDecoder),
      CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer)
      CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer),
#endif

#ifdef ENABLE_BLINGFIRE
      ,
      CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker)
      CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker),
#endif
  );
  return op_loader.GetCustomOps();
}
      []() { return nullptr; });

FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = TokenizerLoader;
  return op_loader.GetCustomOps();
};
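With this registration change, the trie ops ship in the same custom-op library as the existing tokenizers whenever ENABLE_TRIE_TOKENIZER is defined. A short sketch of how a consumer would load them from onnxruntime, assuming the usual onnxruntime-extensions registration path; the model filename is hypothetical (for example, the model hand-built in the earlier sketch):

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())      # makes TrieTokenizer / TrieDetokenizer visible
sess = ort.InferenceSession("trie_tokenizer.onnx", so)   # hypothetical model path
ids = sess.run(None, {"str": np.asarray(["I love you"], dtype=object)})[0]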
@@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "narrow.h"

#include <vector>
#include <set>
#include <map>
#include <string>
#include <memory>
#include <sstream>
#include <charconv>
#include <optional>

#include "unescape.h"

// This Trie Tree is a C++ implementation of
// https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/rwkv_tokenizer.py
// Perf optimized by leveraging C++ features, but the algorithm is the same.
class TrieTree {
 public:
  static constexpr int kMaxTokenLength_ = 128;

  TrieTree(unsigned char ch = 0) : ch_(ch), to_(256) {}

  void add(const std::string& key, int idx = 0,
           std::optional<int> value = std::optional<int>()) {
    if (idx == key.length()) {
      if (!value) {
        value = key[0];
      }
      value_ = value;
      return;
    }

    unsigned char ch = static_cast<unsigned char>(key[idx]);
    if (to_[ch] == nullptr) {
      to_[ch] = std::make_unique<TrieTree>(ch);
    }
    to_[ch]->add(key, idx + 1, value);
  }

  int find_longest(const std::string& key, size_t& idx) {
    const TrieTree* u = this;
    unsigned char ch = key[idx];

    int tok_id = 0;
    size_t idx_end = idx;
    while (u->to_[ch]) {
      u = u->to_[ch].get();
      idx += 1;
      if (u->value_) {
        tok_id = *u->value_;
        idx_end = idx;
      }
      if (idx == key.length()) {
        break;
      }
      ch = key[idx];
    }

    idx = idx_end;
    return tok_id;
  }

 private:
  std::vector<std::unique_ptr<TrieTree>> to_;
  std::optional<int> value_;
  unsigned char ch_;
};

class TrieTokenizer {
 private:
  std::map<int, std::string> idx2token;
  TrieTree root;

 public:
  TrieTokenizer(const std::string& text_tokens) {
    std::istringstream file(text_tokens);
    std::string line;

    while (std::getline(file, line)) {
      auto l_ws = line.find(' ');
      auto r_ws = line.rfind(' ');
      if (l_ws == std::string::npos || r_ws == std::string::npos || l_ws == r_ws) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      int idx = 0;
      std::from_chars(line.data(), line.data() + line.size(), idx);
      if (idx == 0) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] bad index in vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
      std::string x;
      int key_length = 0;
      if (ort_extensions::UnquoteString(raw, x)) {
        std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
      }
      if (x.length() != key_length) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] bad len in vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      idx2token[idx] = x;
    }

    for (const auto& kv : idx2token) {
      root.add(kv.second, 0, kv.first);
    }
  }

  std::vector<int> encodeBytes(const std::string& src) {
    size_t idx = 0;
    std::vector<int> tokens;
    while (idx < src.length()) {
      auto result = root.find_longest(src, idx);
      tokens.push_back(result);
    }

    return tokens;
  }

  std::string decodeBytes(const std::vector<int>& tokens) {
    std::string result;
    for (const auto& i : tokens) {
      result += idx2token[i];
    }
    return result;
  }
};

struct KernelTrieTokenizer : public BaseKernel {
 private:
  std::shared_ptr<TrieTokenizer> tokenizer;

 public:
  KernelTrieTokenizer(const OrtApi& api, const OrtKernelInfo& info)
      : BaseKernel(api, info) {
    std::string text_tokens = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
    tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
  };

  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<int64_t>& tokenize_output) const {
    std::vector<std::string> str_input{input.Data()};
    const auto& input_dim = input.Shape();

    size_t max_length = 0;
    std::vector<std::vector<int64_t>> tokenize_results;
    for (auto& str : str_input) {
      auto tokens = tokenizer->encodeBytes(str);
      std::vector<int64_t> tokens_int64(tokens.begin(), tokens.end());
      max_length = std::max(max_length, tokens_int64.size());
      tokenize_results.emplace_back(tokens_int64);
    }

    std::vector<int64_t> output_dim = input_dim;
    output_dim.push_back(max_length);
    auto* token = tokenize_output.Allocate(output_dim);

    int idx = 0;
    for (auto& res : tokenize_results) {
      for (int64_t id : res) {
        token[idx] = id;
        idx++;
      }

      for (size_t i = res.size(); i < max_length; i++) {
        token[idx] = 0;
        idx++;
      }
    }

    for (auto& result : tokenize_results) {
      result.resize(max_length, 0);
    }
  }
};

struct KernelTrieDetokenizer : public BaseKernel {
 private:
  std::shared_ptr<TrieTokenizer> tokenizer;

 public:
  KernelTrieDetokenizer(const OrtApi& api, const OrtKernelInfo& info)
      : BaseKernel(api, info) {
    std::string text_tokens = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
    tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
  };

  void Compute(const ortc::Tensor<int64_t>& tokens, ortc::Tensor<std::string>& text) const {
    const int64_t* p_ids = tokens.Data();
    const auto& ids_dim = tokens.Shape();
    std::vector<int64_t> output_dim = {1};
    if (ids_dim.size() > 1) {
      output_dim.resize(ids_dim.size() - 1);
      std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
    }

    std::vector<std::string> output(output_dim[0]);
    for (auto n = 0; n < output_dim[0]; n++) {
      std::vector<int> ids;
      for (auto i = 0; i < ids_dim[1]; i++) {
        ids.push_back(ort_extensions::narrow<int>(p_ids[n * ids_dim[1] + i]));
      }
      output[n] = tokenizer->decodeBytes(ids);
    }

    text.SetStringOutput(output, output_dim);
  }
};
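KernelTrieTokenizer::Compute above tokenizes each input string independently and then right-pads every row with 0 up to the longest row, so the output is a dense int64 tensor of shape [N, max_length]. The following sketch reproduces that batching behavior in plain numpy; it is an illustration inferred from the loop above, not part of the extension:

import numpy as np

def pad_rows(rows, pad_id=0):
    # mirror of the padding loop in KernelTrieTokenizer::Compute
    max_len = max(len(r) for r in rows)
    return np.array([r + [pad_id] * (max_len - len(r)) for r in rows], dtype=np.int64)

print(pad_rows([[74, 31337, 22799], [74]]))
# [[   74 31337 22799]
#  [   74     0     0]]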
@@ -0,0 +1,190 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace ort_extensions {

inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
inline bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }

inline unsigned int hex_digit_to_int(char c) {
  unsigned int x = static_cast<unsigned char>(c);
  if (x > '9') {
    x += 9;
  }
  return x & 0xf;
}

inline bool IsSurrogate(char32_t c) {
  return c >= 0xD800 && c <= 0xDFFF;
}

size_t EncodeUTF8Char(char* buffer, char32_t utf8_char) {
  if (utf8_char <= 0x7F) {
    *buffer = static_cast<char>(utf8_char);
    return 1;
  } else if (utf8_char <= 0x7FF) {
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xC0 | utf8_char);
    return 2;
  } else if (utf8_char <= 0xFFFF) {
    buffer[2] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xE0 | utf8_char);
    return 3;
  } else {
    buffer[3] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[2] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xF0 | utf8_char);
    return 4;
  }
}

// Unescape a Python escaped string
inline bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
  // reserve enough space for the worst case, and final size will be calculated at the end.
  unescaped.resize(source.length());
  char* d = unescaped.data();
  const char* p = source.data();
  const char* end = p + source.size();
  const char* last_byte = end - 1;

  while (p == d && p < end && *p != '\\') p++, d++;

  while (p < end) {
    if (*p != '\\') {
      *d++ = *p++;
    } else {
      if (++p > last_byte) {
        return false;
      }
      switch (*p) {
        case 'n':
          *d++ = '\n';
          break;
        case 'r':
          *d++ = '\r';
          break;
        case 't':
          *d++ = '\t';
          break;
        case '\\':
          *d++ = '\\';
          break;
        case '\'':
          *d++ = '\'';
          break;
        case '"':
          *d++ = '\"';
          break;
        case 'x':
        case 'X': {
          if (p >= last_byte) {
            return false;
          } else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
            return false;
          }
          unsigned int ch = 0;
          const char* hex_start = p;
          while (p < last_byte &&
                 IsHexDigit(static_cast<unsigned char>(p[1])))
            ch = (ch << 4) + hex_digit_to_int(*++p);
          if (ch > 0xFF && !is_bytes) {
            return false;
          }
          if (is_bytes) {
            *d++ = static_cast<char>(ch);
          } else {
            d += EncodeUTF8Char(d, static_cast<char32_t>(ch));
          }
          break;
        }
        case 'u': {
          char32_t rune = 0;
          const char* hex_start = p;
          if (p + 4 >= end) {
            return false;
          }
          for (int i = 0; i < 4; ++i) {
            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
              rune = (rune << 4) + hex_digit_to_int(*++p);
            } else {
              return false;
            }
          }
          if (IsSurrogate(rune)) {
            return false;
          }
          d += EncodeUTF8Char(d, rune);
          break;
        }
        case 'U': {
          char32_t rune = 0;
          const char* hex_start = p;
          if (p + 8 >= end) {
            return false;
          }
          for (int i = 0; i < 8; ++i) {
            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
              uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
              if (newrune > 0x10FFFF) {
                return false;
              } else {
                rune = newrune;
              }
            } else {
              return false;
            }
          }
          if (IsSurrogate(rune)) {
            return false;
          }
          d += EncodeUTF8Char(d, rune);
          break;
        }
        default: {
          return false;
        }
      }
      p++;
    }
  }

  unescaped.resize(d - unescaped.data());
  return true;
}

inline bool UnquoteString(const std::string& str, std::string& unquoted) {
  bool is_bytes = false;
  int idx_0 = 0;
  if (str.front() == 'b') {
    is_bytes = true;
    idx_0 = 1;
  }
  std::string str_view(str.data() + idx_0, str.length() - idx_0);
  if (str_view.length() < 2) {
    return false;
  }

  if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
    return false;
  }

  // unescape the string
  return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
}

}  // namespace ort_extensions
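UnquoteString and Unescape above exist to read the RWKV vocab format, in which every line is "<id> <Python string or bytes literal> <byte length>". A small Python sketch of the same parsing, mirroring the reference TRIE_TOKENIZER copied into the test below (the test uses eval where this sketch uses ast.literal_eval); the sample line is hypothetical, not copied from the real vocab file:

import ast

line = "33 '\\t' 1"   # hypothetical line in the "<id> <token literal> <byte length>" vocab format
idx = int(line[:line.index(' ')])
tok = ast.literal_eval(line[line.index(' '):line.rindex(' ')])
tok = tok.encode("utf-8") if isinstance(tok, str) else tok   # bytes literals (b'...') stay as bytes
assert len(tok) == int(line[line.rindex(' '):])
print(idx, tok)   # 33 b'\t'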
@@ -6,3 +6,4 @@ protobuf < 4.0.0
onnxruntime >=1.12.0
transformers >=4.9.2
tensorflow_text >=2.5.0;python_version < '3.11'
requests >=2.26.0
setup.py (14 changed lines)
@@ -74,13 +74,13 @@ class BuildCMakeExt(_build_ext):
        project_dir = pathlib.Path().absolute()
        build_temp = pathlib.Path(self.build_temp)
        build_temp.mkdir(parents=True, exist_ok=True)
        ext_fullpath = pathlib.Path(self.get_ext_fullpath(extension.name))
        ext_fullpath = pathlib.Path(self.get_ext_fullpath(extension.name)).absolute()

        config = 'RelWithDebInfo' if self.debug else 'Release'
        cmake_args = [
            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()),
            '-DOCOS_BUILD_PYTHON=ON',
            '-DOCOS_EXTENTION_NAME=' + ext_fullpath.name,
            '-DOCOS_PYTHON_MODULE_PATH=' + str(ext_fullpath),
            '-DCMAKE_BUILD_TYPE=' + config
        ]

@@ -154,16 +154,6 @@ class BuildCMakeExt(_build_ext):
        if not self.dry_run:
            self.spawn([cmake_exe, '--build', str(build_temp)] + build_args)

        if sys.platform == "win32":
            config_dir = '.'
            if not (build_temp / 'build.ninja').exists():
                config_dir = config
            self.copy_file(build_temp / 'bin' / config_dir / 'extensions_pydll.dll', ext_fullpath,
                           link='hard' if self.debug else None)
        else:
            self.copy_file(build_temp / 'lib' / ext_fullpath.name, ext_fullpath,
                           link='sym' if self.debug else None)


class Build(_build):
    def initialize_options(self) -> None:
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###########################################################################
import os
import tempfile
import requests

from unittest import TestCase, main as unittest_main
from onnxruntime_extensions import OrtPyFunction, util


# To avoid installing the rwkv LM package, we copy the tokenizer code here.
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

class TRIE:
    __slots__ = tuple("ch,to,values,front".split(","))
    to: list
    values: set

    def __init__(self, front=None, ch=None):
        self.ch = ch
        self.to = [None for ch in range(256)]
        self.values = set()
        self.front = front

    def __repr__(self):
        fr = self
        ret = []
        while (fr != None):
            if (fr.ch != None):
                ret.append(fr.ch)
            fr = fr.front
        return "<TRIE %s %s>" % (ret[::-1], self.values)

    def add(self, key: bytes, idx: int = 0, val=None):
        if (idx == len(key)):
            if (val is None):
                val = key
            self.values.add(val)
            return self
        ch = key[idx]
        if (self.to[ch] is None):
            self.to[ch] = TRIE(front=self, ch=ch)
        return self.to[ch].add(key, idx=idx + 1, val=val)

    def find_longest(self, key: bytes, idx: int = 0):
        u: TRIE = self
        ch: int = key[idx]

        while (u.to[ch] is not None):
            u = u.to[ch]
            idx += 1
            if (u.values):
                ret = idx, u, u.values
            if (idx == len(key)):
                break
            ch = key[idx]
        return ret


class TRIE_TOKENIZER():
    def __init__(self, file_name):
        self.idx2token = {}
        sorted = []  # must be already sorted
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        self.root = TRIE()
        for t, i in self.token2idx.items():
            _ = self.root.add(t, val=(t, i))

    def encodeBytes(self, src: bytes):
        idx: int = 0
        tokens = []
        while (idx < len(src)):
            _idx: int = idx
            idx, _, values = self.root.find_longest(src, idx)
            assert (idx != _idx)
            _, token = next(iter(values))
            tokens.append(token)
        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        try:
            return self.decodeBytes(tokens).decode('utf-8')
        except:
            return '\ufffd'  # bad utf-8

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except:
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()


########################################################################################################


class TestTrieTokenizer(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        url = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
        # Create a temporary directory and file path
        temp_dir = tempfile.mkdtemp()
        file_name = os.path.basename(url)  # Gets the file name from the URL
        cls.vocab_file = os.path.join(temp_dir, file_name)
        response = requests.get(url)
        with open(cls.vocab_file, "wb") as f:
            f.write(response.content)

    def test_trie_tokenizer(self):
        tokr = TRIE_TOKENIZER(self.vocab_file)
        src = "I love you"
        tokens = tokr.encode(src)
        self.assertEqual(tokens, [74, 31337, 22799])
        self.assertEqual(tokr.decode(tokens), src)

    def test_ort_trie_tokenizer(self):
        vocab_data = util.read_file(self.vocab_file, 'rb')
        tokr = OrtPyFunction.from_customop("TrieTokenizer", vocab=vocab_data)
        tokens = tokr(["I love you"])
        self.assertEqual(list(tokens[0]), [74, 31337, 22799])

        detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data)
        self.assertEqual(list(detok(tokens)), ["I love you"])

    def test_parity(self):
        test_sentences = [
            "I am a girl",
            "我是个女孩",
            "私は女の子です",
            "广东人爱吃云吞面,还有腌面、竹升面,车仔面、油渣面、普宁面线、伊面等各种圆扁粗细,加碱水,不加碱水的面",
            "我是个人类",
            "I am a human",
            "that dog is so cute",
            "私はねこむすめです、にゃん♪",
            "宇宙级特大事件!号外号外!"
        ]

        tokr = TRIE_TOKENIZER(self.vocab_file)

        ortx_tokr = OrtPyFunction.from_customop("TrieTokenizer", vocab=util.read_file(self.vocab_file, 'rb'))
        for s in test_sentences:
            self.assertEqual(tokr.encode(s), list(ortx_tokr([s])[0]))


if __name__ == "__main__":
    unittest_main()