Add the tokenizer C ABI (#693)
* initial check-ins
* fix the selected-ops build failures
* add the tokenization implementation
* update the Windows DEF file for the C ABI in the CMake file
* fix the build on Linux
* fix some warnings and remove the unused code
* initial import of unit tests from tfmtok
* add streaming API support
* fix the merges loading issues
* complete export from tfmtok - needs input id fixing
* fix the unit test failures
* fix all unit test failures
* refactor streaming code
* remove the unused code
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
This commit is contained in:
Parent
1f31d33ed4
Commit
a8bce4328b
@ -564,6 +564,7 @@ standardize_output_folder(ocos_operators)
target_include_directories(noexcep_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${GSL_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base

@ -571,6 +572,7 @@ target_include_directories(noexcep_operators PUBLIC
target_include_directories(ocos_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${GSL_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base

@ -669,7 +671,7 @@ if(OCOS_ENABLE_BLINGFIRE)
endif()

if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
target_include_directories(ocos_operators PRIVATE ${nlohmann_json_SOURCE_DIR}/single_include)
target_include_directories(ocos_operators PUBLIC ${nlohmann_json_SOURCE_DIR}/single_include)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()

@ -684,7 +686,6 @@ if(ANDROID)
list(APPEND ocos_libraries log)
endif()

target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
list(APPEND ocos_libraries Microsoft.GSL::GSL)

list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)

@ -705,6 +706,10 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})

file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h")
if (NOT OCOS_ENABLE_C_API)
file(GLOB shared_TARGET_LIB_C_API_SRC "shared/lib/*tokenizer*")
list(REMOVE_ITEM shared_TARGET_LIB_SRC ${shared_TARGET_LIB_C_API_SRC})
endif()

if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
add_executable(ortcustomops ${shared_TARGET_LIB_SRC})

@ -805,7 +810,12 @@ target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators
target_link_libraries(ortcustomops PUBLIC ocos_operators)

if(_BUILD_SHARED_LIBRARY)
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h")
if (OCOS_ENABLE_C_API)
list(APPEND shared_TARGET_SRC "shared/extensions_c.def")
else()
list(APPEND shared_TARGET_SRC "shared/ortcustomops.def")
endif()
add_library(extensions_shared SHARED ${shared_TARGET_SRC})

# We need to propagate OCOS_SHARED_LIBRARY if set.
@ -12,12 +12,6 @@ struct OrtxStatus::Rep {
OrtxStatus::OrtxStatus() = default;
OrtxStatus::~OrtxStatus() = default;

OrtxStatus::OrtxStatus(extError_t code, std::string_view error_message)
: rep_(new Rep) {
rep_->code = code;
rep_->error_message = std::string(error_message);
}

OrtxStatus::OrtxStatus(extError_t code, const std::string& error_message)
: rep_(new Rep) {
rep_->code = code;

@ -58,3 +52,51 @@ OrtStatus* OrtxStatus::CreateOrtStatus() const {
OrtStatus* status = OrtW::CreateStatus(Message(), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
return status;
}

std::string OrtxStatus::ToString() const {
if (rep_ == nullptr)
return "OK";

std::string result;
switch (Code()) {
case extError_t::kOrtxOK:
result = "Success";
break;
case extError_t::kOrtxErrorInvalidArgument:
result = "Invalid argument";
break;
case extError_t::kOrtxErrorOutOfMemory:
result = "Out of Memory";
break;
case extError_t::kOrtxErrorCorruptData:
result = "Corrupt data";
break;
case extError_t::kOrtxErrorInvalidFile:
result = "Invalid data file";
break;
case extError_t::kOrtxErrorNotFound:
result = "Not found";
break;
case extError_t::kOrtxErrorAlreadyExists:
result = "Already exists";
break;
case extError_t::kOrtxErrorOutOfRange:
result = "Out of range";
break;
case extError_t::kOrtxErrorNotImplemented:
result = "Not implemented";
break;
case extError_t::kOrtxErrorInternal:
result = "Internal";
break;
case extError_t::kOrtxErrorUnknown:
result = "Unknown";
break;
default:
break;
}

result += ": ";
result += rep_->error_message;
return result;
}
@ -11,7 +11,6 @@ struct OrtStatus;
struct OrtxStatus {
OrtxStatus();
~OrtxStatus();
OrtxStatus(extError_t code, std::string_view error_message);
OrtxStatus(extError_t code, const std::string& error_message);
OrtxStatus(const OrtxStatus& s);
OrtxStatus& operator=(const OrtxStatus& s);

@ -22,6 +21,7 @@ struct OrtxStatus {
void SetErrorMessage(const char* str);
[[nodiscard]] const char* Message() const;
[[nodiscard]] extError_t Code() const;
std::string ToString() const;

OrtStatus* CreateOrtStatus() const;
@ -58,6 +58,24 @@ class ustring : public std::u32string {
return std::string(utf8_buf);
}

static size_t UTF8Len(char byte1) {
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(byte1) >> 4;
return lookup[highbits];
}

static size_t UTF8Len(char32_t codepoint) {
if (codepoint <= 0x7F) {
return 1;
} else if (codepoint <= 0x7FF) {
return 2;
} else if (codepoint <= 0xFFFF) {
return 3;
} else {
return 4;
}
}

static bool ValidateUTF8(const std::string& data) {
int cnt = 0;
for (auto i = 0; i < data.size(); i++) {
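As a quick sanity check of the new helpers, the lead-byte overload of UTF8Len derives the sequence length from the high four bits of the first byte; the sample values below are illustrative only and assume ustring.h and <cassert> are available.

#include <cassert>
// High nibble of a UTF-8 lead byte maps to the sequence length via the lookup table above:
// 0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
assert(ustring::UTF8Len('A') == 1);                      // U+0041, single-byte character
assert(ustring::UTF8Len(static_cast<char>(0xC3)) == 2);  // lead byte of U+00E9
assert(ustring::UTF8Len(static_cast<char>(0xE2)) == 3);  // lead byte of U+20AC
assert(ustring::UTF8Len(static_cast<char>(0xF0)) == 4);  // lead byte of U+1F600
assert(ustring::UTF8Len(char32_t(0x20AC)) == 3);         // code-point overload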
@ -199,19 +199,22 @@ else()
endif()
endblock()

block()
file(GLOB tokenizer_TEST_SRC
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
"${TEST_SRC_DIR}/tokenizer_test/*.hpp")
if (OCOS_ENABLE_C_API)
file(GLOB tokenizer_TEST_SRC
"${TEST_SRC_DIR}/tokenizer_test/*.c"
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
"${TEST_SRC_DIR}/tokenizer_test/*.h")

add_test_target(TARGET tokenizer_api_test
TEST_SOURCES ${tokenizer_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
add_test_target(TARGET tokenizer_api_test
TEST_SOURCES ${tokenizer_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)

target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})

endblock()
target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(tokenizer_api_test PRIVATE
${PROJECT_SOURCE_DIR}/
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
endif()
endif()

endif()
@ -390,11 +390,14 @@ STATIC_INLINE uint32_t Mur(uint32_t a, uint32_t h) {
template <typename T> STATIC_INLINE T DebugTweak(T x) {
if (debug_mode) {
if (sizeof(x) == 4) {
x = ~Bswap32(x * c1);
} else {
x = ~Bswap64(x * k1);
}
x = ~Bswap32(x * c1);
}
return x;
}

template <> uint64_t DebugTweak(uint64_t x) {
if (debug_mode) {
x = ~Bswap64(x * k1);
}
return x;
}

@ -461,7 +464,7 @@ STATIC_INLINE uint64_t HashLen0to16(const char *s, size_t len) {
uint8_t b = s[len >> 1];
uint8_t c = s[len - 1];
uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
uint32_t z = len + (static_cast<uint32_t>(c) << 2);
uint32_t z = static_cast<uint32_t>(len + (static_cast<uint32_t>(c) << 2));
return ShiftMix(y * k2 ^ z * k0) * k2;
}
return k2;

@ -1002,7 +1005,7 @@ namespace farmhashnt {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}

uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {

@ -1039,7 +1042,7 @@ STATIC_INLINE uint32_t Hash32Len13to24(const char *s, size_t len, uint32_t seed
uint32_t d = Fetch(s + (len >> 1));
uint32_t e = Fetch(s);
uint32_t f = Fetch(s + len - 4);
uint32_t h = d * c1 + len + seed;
uint32_t h = static_cast<uint32_t>(d * c1 + len + seed);
a = Rotate(a, 12) + f;
h = Mur(c, h) + a;
a = Rotate(a, 3) + c;

@ -1057,11 +1060,11 @@ STATIC_INLINE uint32_t Hash32Len0to4(const char *s, size_t len, uint32_t seed =
b = b * c1 + v;
c ^= b;
}
return fmix(Mur(b, Mur(len, c)));
return fmix(Mur(b, Mur(static_cast<uint32_t>(len), c)));
}

STATIC_INLINE uint32_t Hash32Len5to12(const char *s, size_t len, uint32_t seed = 0) {
uint32_t a = len, b = len * 5, c = 9, d = b + seed;
uint32_t a = static_cast<uint32_t>(len), b = static_cast<uint32_t>(len * 5), c = 9, d = b + seed;
a += Fetch(s);
b += Fetch(s + len - 4);
c += Fetch(s + ((len >> 1) & 4));

@ -1076,7 +1079,7 @@ uint32_t Hash32(const char *s, size_t len) {
}

// len > 24
uint32_t h = len, g = c1 * len, f = g;
uint32_t h = static_cast<uint32_t>(len), g = static_cast<uint32_t>(c1 * len), f = g;
uint32_t a0 = Rotate(Fetch(s + len - 4) * c1, 17) * c2;
uint32_t a1 = Rotate(Fetch(s + len - 8) * c1, 17) * c2;
uint32_t a2 = Rotate(Fetch(s + len - 16) * c1, 17) * c2;

@ -1132,7 +1135,7 @@ uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
else if (len >= 5) return Hash32Len5to12(s, len, seed);
else return Hash32Len0to4(s, len, seed);
}
uint32_t h = Hash32Len13to24(s, 24, seed ^ len);
uint32_t h = Hash32Len13to24(s, 24, seed ^ static_cast<uint32_t>(len));
return Mur(Hash32(s + 24, len - 24) + seed, h);
}
} // namespace farmhashmk

@ -1141,7 +1144,7 @@ namespace farmhashsu {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}

uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {

@ -1361,7 +1364,7 @@ namespace farmhashsa {
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
return s == NULL ? 0 : len;
return s == NULL ? 0 : static_cast<uint32_t>(len);
}

uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {

@ -1587,7 +1590,7 @@ STATIC_INLINE uint32_t Hash32Len13to24(const char *s, size_t len) {
uint32_t d = Fetch(s + (len >> 1));
uint32_t e = Fetch(s);
uint32_t f = Fetch(s + len - 4);
uint32_t h = len;
uint32_t h = static_cast<uint32_t>(len);

return fmix(Mur(f, Mur(e, Mur(d, Mur(c, Mur(b, Mur(a, h)))))));
}

@ -1600,11 +1603,11 @@ STATIC_INLINE uint32_t Hash32Len0to4(const char *s, size_t len) {
b = b * c1 + v;
c ^= b;
}
return fmix(Mur(b, Mur(len, c)));
return fmix(Mur(b, Mur(static_cast<uint32_t>(len), c)));
}

STATIC_INLINE uint32_t Hash32Len5to12(const char *s, size_t len) {
uint32_t a = len, b = len * 5, c = 9, d = b;
uint32_t a = static_cast<uint32_t>(len), b = static_cast<uint32_t>(len * 5), c = 9, d = b;
a += Fetch(s);
b += Fetch(s + len - 4);
c += Fetch(s + ((len >> 1) & 4));

@ -1619,7 +1622,7 @@ uint32_t Hash32(const char *s, size_t len) {
}

// len > 24
uint32_t h = len, g = c1 * len, f = g;
uint32_t h = static_cast<uint32_t>(len), g = static_cast<uint32_t>(c1 * len), f = g;
uint32_t a0 = Rotate(Fetch(s + len - 4) * c1, 17) * c2;
uint32_t a1 = Rotate(Fetch(s + len - 8) * c1, 17) * c2;
uint32_t a2 = Rotate(Fetch(s + len - 16) * c1, 17) * c2;

@ -1686,7 +1689,7 @@ uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
else if (len >= 5) return farmhashmk::Hash32Len5to12(s, len, seed);
else return farmhashmk::Hash32Len0to4(s, len, seed);
}
uint32_t h = farmhashmk::Hash32Len13to24(s, 24, seed ^ len);
uint32_t h = farmhashmk::Hash32Len13to24(s, 24, seed ^ static_cast<uint32_t>(len));
return Mur(Hash32(s + 24, len - 24) + seed, h);
}

@ -1736,7 +1739,7 @@ STATIC_INLINE uint64_t HashLen0to16(const char *s, size_t len) {
uint8_t b = s[len >> 1];
uint8_t c = s[len - 1];
uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
uint32_t z = len + (static_cast<uint32_t>(c) << 2);
uint32_t z = static_cast<uint32_t>(len + (static_cast<uint32_t>(c) << 2));
return ShiftMix(y * k2 ^ z * k0) * k2;
}
return k2;

@ -1775,7 +1778,7 @@ STATIC_INLINE NAMESPACE_FOR_HASH_FUNCTIONS::uint128_t CityMurmur(const char *s,
uint64_t b = Uint128High64(seed);
uint64_t c = 0;
uint64_t d = 0;
signed long l = len - 16;
signed long l = static_cast<signed long>(len) - 16;
if (l <= 0) { // len <= 16
a = ShiftMix(a * k1) * k1;
c = b * k1 + HashLen0to16(s, len);
@ -78,7 +78,6 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/** \brief Create a tokenizer object with the specified tokenizer path
*
* \param tokenizer Pointer to store the created tokenizer object
@ -19,6 +19,8 @@
struct KernelBpeDecoder {
public:
virtual ~KernelBpeDecoder() = default;

OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;

@ -96,27 +98,26 @@ struct KernelBpeDecoder {
void BuildIdVocab(const std::string& vocab) {
arr_vocab_.reserve(vocab.size() / 2); // give a rough estimation.

std::u32string u_vocab = ustring(vocab);
std::u32string_view uv_vocab(u_vocab);
std::string_view v_vocab(vocab);
size_t last_pos = 0;

auto ccount = uv_vocab.size();
for (size_t n = 0; n < ccount; ++n) {
if (uv_vocab[n] == char32_t('\n')) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos);
arr_vocab_.emplace_back(ustring(s_tok));
auto c_count = v_vocab.size();
for (size_t n = 0; n < c_count; ++n) {
if (v_vocab[n] == '\n') {
std::string_view s_tok = v_vocab.substr(last_pos, n - last_pos);
arr_vocab_.emplace_back(std::string(s_tok));
last_pos = n + 1;
} else if (n == ccount - 1) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos + 1);
arr_vocab_.emplace_back(ustring(s_tok));
} else if (n == c_count - 1) {
std::string_view s_tok = v_vocab.substr(last_pos, n - last_pos + 1);
arr_vocab_.emplace_back(std::string(s_tok));
}
}

arr_vocab_.shrink_to_fit();
}

OrtStatusPtr Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};

@ -146,10 +147,10 @@ struct KernelBpeDecoder {
}

if (added_tokens_.count(token)) {
const std::string ws = added_tokens_.at(token);
const std::string& ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
} else if (static_cast<size_t>(token) < arr_vocab_.size()) {
const auto str = arr_vocab_[token];
const auto str = ustring(arr_vocab_[token]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {

@ -173,6 +174,14 @@ struct KernelBpeDecoder {
decoded_token = unk_token_;
}
}
// remove the end_of_word_suffix like </w> or </s> etc.
if (end_of_word_suffix_.size() > 0) {
if (decoded_token.size() >= end_of_word_suffix_.size() &&
decoded_token.substr(decoded_token.size() - end_of_word_suffix_.size()) == end_of_word_suffix_) {
decoded_token = decoded_token.substr(0, decoded_token.size() - end_of_word_suffix_.size());
decoded_token += ' ';
}
}

if (whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {

@ -193,10 +202,10 @@ struct KernelBpeDecoder {
p_ids += seq_len;
}
output.SetStringOutput(decoded_strings, output_dim);
return nullptr;
return {};
}

private:
protected:
std::string bos_token_{"<|endoftext|>"};
std::string eos_token_{"<|endoftext|>"};
std::string unk_token_{"<|endoftext|>"};

@ -206,8 +215,9 @@ struct KernelBpeDecoder {
int64_t en_normalization_ = 0;
int64_t skip_special_tokens_ = 0;
int64_t whitespace_token_ = 0;
std::vector<ustring> arr_vocab_;
std::vector<std::string> arr_vocab_;
std::map<char32_t, unsigned char> byte_decoder_;
std::map<int64_t, std::string> added_tokens_;
std::set<int64_t> all_special_ids_;
std::string end_of_word_suffix_;
};
@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <fstream>
#include <filesystem>
#include "ocos.h"
#include "status.h"
#include "nlohmann/json.hpp"

#include "bpe_types.h"

namespace ort_extensions::bpe {

class TokenJsonConfig final {
public:
TokenJsonConfig() {}
~TokenJsonConfig() {}
using json = nlohmann::json;
using json_pointer = nlohmann::json_pointer<std::string>;

public:
OrtxStatus Load(const std::string& json_path) {
if (json_path.empty()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}

auto file_path = std::filesystem::path(json_path) / "tokenizer_config.json";
std::ifstream ifs(file_path);
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
}

vocab_path_ = (std::filesystem::path(json_path) / "tokenizer.json").string();
nlohmann::json json_config = nlohmann::json::parse(ifs);
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
clean_up_tokenization_spaces_ = json_config.value("clean_up_tokenization_spaces", false);
model_max_length_ = json_config.value("model_max_length", 1e+30);

tokenizer_class_ = json_config.value("tokenizer_class", "");

auto tok_iter = json_config.find("bos_token");
if (tok_iter != json_config.end() && tok_iter->is_object()) {
bos_token_ = tok_iter->value("content", "");
eos_token_ = json_config.value("/eos_token/content"_json_pointer, "");
unk_token_ = json_config.value("/unk_token/content"_json_pointer, "");
} else {
bos_token_ = json_config.value("bos_token", "");
eos_token_ = json_config.value("eos_token", "");
unk_token_ = json_config.value("unk_token", "");
}

auto pad_iter = json_config.find("pad_token");
if (pad_iter != json_config.end() && pad_iter->is_string()) {
pad_token_ = json_config.value("pad_token", "");
}

if (tokenizer_class_.empty()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get tokenizer class.");
}

return {};
}

const std::string& GetVocabDataFile() const {
return vocab_path_;
}

public:
bool add_bos_token_{};
bool add_eos_token_{};
bool clean_up_tokenization_spaces_{};
double model_max_length_{};

std::string tokenizer_class_;
std::string bos_token_;
std::string eos_token_;
std::string unk_token_;
std::string pad_token_;

private:
std::string vocab_path_;
};

} // namespace ort_extensions::bpe
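For orientation, a minimal sketch of the tokenizer_config.json shape that TokenJsonConfig::Load expects; the keys mirror the json_config lookups above, while the values are only illustrative examples and are not taken from this change.

{
  "add_bos_token": true,
  "add_eos_token": false,
  "clean_up_tokenization_spaces": false,
  "model_max_length": 4096,
  "tokenizer_class": "LlamaTokenizer",
  "bos_token": { "content": "<s>" },
  "eos_token": { "content": "</s>" },
  "unk_token": { "content": "<unk>" },
  "pad_token": "</s>"
}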
@ -1,15 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "bpe_tokenizer.hpp"
#include "bpe_kernels.h"
#include "ortx_common.h"
#include "bpe_kernels.h"
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"

#include <optional>
#include <limits>

using namespace ort_extensions;

const char kModel_Default[] = "PreTrained";
const char kModel_GPT2[] = "GPT2";
const char kModel_CodeGen[] = "CodeGen";
const char kModel_Roberta[] = "Roberta";

@ -43,40 +45,39 @@ std::string BpeModelConf::GetSpecialTokens() const {
}

// Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
bool IsUnicodeSpace(char32_t ch) {
switch (ch) {
case 0x0009:
case 0x000A:
case 0x000B:
case 0x000C:
case 0x000D:
case 0x001C:
case 0x001D:
case 0x001E:
case 0x001F:
case 0x0020:
case 0x0085:
case 0x00A0:
case 0x1680:
case 0x2000:
case 0x2001:
case 0x2002:
case 0x2003:
case 0x2004:
case 0x2005:
case 0x2006:
case 0x2007:
case 0x2008:
case 0x2009:
case 0x200A:
case 0x2028:
case 0x2029:
case 0x202F:
case 0x205F:
case 0x3000:
return true;
}
return false;
static bool IsUnicodeSpace(char32_t ch) {
const std::set<char32_t> unicode_spaces = {
0x0009, // CHARACTER TABULATION
0x000A, // LINE FEED (LF)
0x000B, // LINE TABULATION
0x000C, // FORM FEED (FF)
0x000D, // CARRIAGE RETURN (CR)
0x001C, // FILE SEPARATOR
0x001D, // GROUP SEPARATOR
0x001E, // RECORD SEPARATOR
0x001F, // UNIT SEPARATOR
0x0020, // SPACE
0x0085, // NEXT LINE (NEL)
0x00A0, // NO-BREAK SPACE
0x1680, // OGHAM SPACE MARK
0x2000, // EN QUAD
0x2001, // EM QUAD
0x2002, // EN SPACE
0x2003, // EM SPACE
0x2004, // THREE-PER-EM SPACE
0x2005, // FOUR-PER-EM SPACE
0x2006, // SIX-PER-EM SPACE
0x2007, // FIGURE SPACE
0x2008, // PUNCTUATION SPACE
0x2009, // THIN SPACE
0x200A, // HAIR SPACE
0x2028, // LINE SEPARATOR
0x2029, // PARAGRAPH SEPARATOR
0x202F, // NARROW NO-BREAK SPACE
0x205F, // MEDIUM MATHEMATICAL SPACE
0x3000, // IDEOGRAPHIC SPACE
};
return unicode_spaces.count(ch) > 0;
}

bool AllSpaceUstring(const ustring& str) {

@ -105,7 +106,7 @@ ustring RemoveConsecutiveSpaces(const ustring& input) {

KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
: bpe_conf_(conf) {
model_name_ = conf.name_;
model_name_ = conf.name_ == nullptr ? "" : conf.name_;
};

OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {

@ -133,13 +134,13 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
model_name_ = model_name;
}

std::stringstream vocabu_stream(vocab);
std::stringstream vocab_stream(vocab);
std::stringstream merges_stream(merges);
bbpe_tokenizer_ = std::make_unique<BpeModel>();
auto status = bbpe_tokenizer_->Load(vocabu_stream,
auto status = bbpe_tokenizer_->Load(vocab_stream,
merges_stream,
bpe_conf_.unk_token_,
bpe_conf_.GetSpecialTokens().c_str(),
bpe_conf_.get().unk_token_,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
if (!status.IsOk()) {
return status.CreateOrtStatus();

@ -153,15 +154,14 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
}

// TODO: need to check if the special token ids are the same as the ones in HFTokenizer
unk_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.unk_token_);
if (bpe_conf_.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.bos_token_);
if (bpe_conf_.get().bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().bos_token_);
}
if (bpe_conf_.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.eos_token_);
if (bpe_conf_.get().eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().eos_token_);
}
if (bpe_conf_.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.pad_token_);
if (bpe_conf_.get().pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.get().pad_token_);
}

return {};

@ -202,10 +202,23 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
return res;
}

if (IsBosEosRequired(ModelName())) {
// Add BOS token to result
bool add_bos_token = false;
if (add_bos_token_.has_value()) {
add_bos_token = add_bos_token_.value();
} else if (IsBosEosRequired(ModelName())) {
add_bos_token = true;
}
if (add_bos_token) {
res.push_back(bos_token_id_);
}

bool add_eos_token = false;
if (add_eos_token_.has_value()) {
add_eos_token = add_eos_token_.value();
} else if (IsBosEosRequired(ModelName())) {
add_eos_token = true;
}

if (ModelName() == kModel_CLIP) {
// Convert to lowercase
std::transform(input.begin(), input.end(), input.begin(), [](char32_t c) { return static_cast<char32_t>(ToLower(c)); });

@ -231,7 +244,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
OffsetMappingType offset_mapping;

if (compute_offset_mapping) {
if (IsBosEosRequired(ModelName())) {
if (add_bos_token) {
// Add offset mapping for BOS token
offset_mapping.push_back(std::make_pair(0, 0));
}

@ -300,7 +313,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}

if (compute_offset_mapping) {
if (IsBosEosRequired(ModelName())) {
if (add_eos_token) {
// Add offset mapping for EOS token
offset_mapping.emplace_back(std::make_pair(0, 0));
}

@ -309,7 +322,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}
}

if (IsBosEosRequired(ModelName())) {
if (add_eos_token) {
// Add EOS token to result
res.push_back(eos_token_id_);
}

@ -529,3 +542,95 @@ static const auto kSpmConfiguration = BpeModelConf{

SpmTokenizer::SpmTokenizer()
: KernelBpeTokenizer(kSpmConfiguration) {}

JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}

OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs(voc_file);
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
}

const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
json_conf_.name_ = model_name_.c_str();
json_conf_.bos_token_ = config.bos_token_.c_str();
json_conf_.eos_token_ = config.eos_token_.c_str();
json_conf_.unk_token_ = config.unk_token_.c_str();
json_conf_.pad_token_ = config.pad_token_.c_str();

// re-bind the configuration object
bpe_conf_ = json_conf_;

// consider using a SAX parser for large json files
nlohmann::json tok_json;
ifs >> tok_json;
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
auto status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));

auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
bpe::AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
added_token.lstrip_ = token.value("lstrip", false);
added_token.normalized_ = token.value("normalized", false);
added_token.rstrip_ = token.value("rstrip", false);
added_token.single_word_ = token.value("single_word", false);
added_token.special_ = token.value("special", false);

added_tokens_.emplace_back(added_token);
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}
}

if (!status.IsOk()) {
return status;
}

status = bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
if (!status.IsOk()) {
return status;
}

add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
// add_bos_token defaults to false; we need to check the post_processor json to see if it should be true
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
std::string text = post_processor->dump();
if (text.find(config.bos_token_) != std::string::npos) {
add_bos_token_ = true;
}
if (text.find(config.eos_token_) != std::string::npos) {
add_eos_token_ = true;
}
}
}

return status;
}

OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
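As a reading aid for the added_tokens loop above, an entry in tokenizer.json has roughly this shape; the keys match the token.value(...) calls, and the values are illustrative only.

{
  "id": 2,
  "content": "</s>",
  "special": true,
  "single_word": false,
  "lstrip": false,
  "rstrip": false,
  "normalized": false
}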
@ -7,12 +7,15 @@
#include "status.h"
#include "ustring.h"

#include <list>
#include <string>
#include <vector>
#include <list>
#include <functional>

#include "bpe_types.h"

struct BpeModelConf {
const char* name_{"GPT2"}; // this name may be overridden by the tokenizer's attribute.
const char* name_{"GPT2"};  // this name may be overridden by the tokenizer's attribute.
const char* unk_token_{"<|endoftext|>"};
const char* bos_token_{"<|endoftext|>"};
const char* eos_token_{"<|endoftext|>"};

@ -21,18 +24,14 @@ struct BpeModelConf {
std::string GetSpecialTokens() const;
};

namespace ort_extensions {
class BpeModel;
}

struct KernelBpeTokenizer {
KernelBpeTokenizer(const BpeModelConf& conf);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);

OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

const std::string& ModelName() const { return model_name_; }

@ -48,25 +47,27 @@ struct KernelBpeTokenizer {
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;

private:
const BpeModelConf& bpe_conf_;
protected:
std::reference_wrapper<BpeModelConf const> bpe_conf_;
std::string model_name_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;

int64_t padding_length_ = -1;
uint32_t unk_token_id_{};
uint32_t bos_token_id_{};
uint32_t eos_token_id_{};
uint32_t pad_token_id_{};

std::optional<bool> add_bos_token_;
std::optional<bool> add_eos_token_;
};

struct GPT2Tokenizer : KernelBpeTokenizer {
GPT2Tokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

@ -75,9 +76,9 @@ struct RobertaTokenizer : KernelBpeTokenizer {
RobertaTokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

@ -86,9 +87,9 @@ struct CLIPTokenizer : KernelBpeTokenizer {
CLIPTokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

@ -97,9 +98,27 @@ struct SpmTokenizer : KernelBpeTokenizer {
SpmTokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compiler.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

class JsonFastTokenizer : KernelBpeTokenizer {
public:
JsonFastTokenizer();
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

public:
auto GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }

private:
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};
@ -0,0 +1,270 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "bpe_kernels.h"
#include "bpe_decoder.hpp"
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"

namespace ort_extensions {
struct BPEDecoderState {
bool f_special_last{};
std::string incomplete_utf8_;
};
} // namespace ort_extensions

class BpeStreamingDecoder : public KernelBpeDecoder {
public:
BpeStreamingDecoder() = default;
~BpeStreamingDecoder() override = default;

using BPEDecoderState = ort_extensions::BPEDecoderState;

// share the data between the encoder and the decoder
OrtxStatus Load(
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
const JsonFastTokenizer& encoder) {
const auto& tok_config = *ptr_config;
bos_token_ = tok_config.bos_token_;
eos_token_ = tok_config.eos_token_;
unk_token_ = tok_config.unk_token_;

auto& a_toks = encoder.GetAddedTokens();
for (const auto& tok : a_toks) {
added_tokens_[tok.id_] = tok.content_;
if (tok.special_) {
all_special_ids_.insert(tok.id_);
}
}

auto& tok_model = encoder.GetEncoder();
CreateByteDecoder(tok_model);
arr_vocab_ = tok_model.BuildDecoder();
end_of_word_suffix_ = tok_model.GetEndOfWordSuffix();
// whitespace_token_ = tok_config.clean_up_tokenization_spaces_ ? 1 : 0;
skip_special_tokens_ = 1;
// en_normalization_ = 0;
add_dummy_prefix_ = tok_config.tokenizer_class_ == "LlamaTokenizer" ? 1 : 0;
eos_token_id_ = encoder.GetEncoder().GetTokenId(tok_config.eos_token_);

tok_config_ = ptr_config;
return {};
}

static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}

return result;
}

static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

OrtxStatus Id2Token(extTokenId_t id,
std::string& token,
bool skip_special_tokens,
bool& f_special_last) const {
bool f_special = all_special_ids_.count(id) ? true : false;
if (skip_special_tokens && f_special) {
f_special_last = f_special;
return {};
}

if (added_tokens_.count(id)) {
const std::string ws = added_tokens_.at(id);
token = (std::string)ws;
} else if (static_cast<size_t>(id) < arr_vocab_.size()) {
const auto str = ustring(arr_vocab_[id]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0 && wchr <= 0xFF) {
token.push_back(gsl::narrow<unsigned char>(wchr));
} else {
unsigned char uchr = byte_decoder_.at(wchr);
token.push_back(uchr);
}
}
} else {
if (skip_special_tokens) {
f_special_last = f_special;
return {};
} else {
token = unk_token_;
}
}

// remove the end_of_word_suffix like </w> or </s> etc.
if (end_of_word_suffix_.size() > 0) {
if (token.size() >= end_of_word_suffix_.size() &&
token.substr(token.size() - end_of_word_suffix_.size()) == end_of_word_suffix_) {
token = token.substr(0, token.size() - end_of_word_suffix_.size());
token += ' ';
}
}

f_special_last = f_special;
return {};
}

OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
const char spm_underscore[] = "\xe2\x96\x81";

std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
bool f_special = false;
if (piece.empty() || all_special_ids_.count(id)) {
token = "";
f_special = true;
} else if (IsSpmByteWord(piece)) {
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
token = {static_cast<char>(strtol(buf, NULL, 16))};
} else {
token = ReplaceAll(piece, spm_underscore, " ");
}

if (!token.empty() && token[0] == ' ' && f_special_last && add_dummy_prefix_) {
token = token.substr(1);
}

f_special_last = f_special;
return {};
}

static bool IsSpmTokenizer(const std::string& tok_class) {
return tok_class == "GemmaTokenizer" || tok_class == "LlamaTokenizer";
}

OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
auto bpe_state = *state;
std::unique_ptr<BPEDecoderState> bpe_state_ptr;
bool is_first = false;
if (bpe_state == nullptr) {
bpe_state_ptr = std::make_unique<BPEDecoderState>();
bpe_state = bpe_state_ptr.get();
is_first = true;
}

bool f_special = bpe_state->f_special_last; // [Spm]Id2Token needs the last state
bool f_special_last = bpe_state->f_special_last;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, token, f_special)
: Id2Token(id, token, true /* tok_config_.skip_special_tokens_ */, f_special);

if (status.IsOk()) {
if (bpe_state_ptr) {
*state = bpe_state_ptr.release();
}

if (tok_config_->clean_up_tokenization_spaces_) {
if (f_special && (is_first && !f_special_last)) {
token = std::string(" ") + token;
}

if (f_special && id != eos_token_id_) {
token.push_back(' ');
}
} // end case of whitespace_token_

if (!bpe_state->incomplete_utf8_.empty()) {
token = bpe_state->incomplete_utf8_ + token;
bpe_state->incomplete_utf8_.clear();
} else {
if (!token.empty() && ustring::UTF8Len(token.front()) > token.size()) {
bpe_state->incomplete_utf8_ = token;
token = "";
}
}
}

bpe_state->f_special_last = f_special;
return status;
}

OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
if (ids_dim.size() > 1) {
output_dim.resize(ids_dim.size() - 1);
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}

size_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
decoded_strings.reserve(string_batch);

for (auto n = string_batch; n > 0; n--) {
bool f_special_last = false;
std::string text;

for (size_t tok_idx = 0; tok_idx < seq_len; ++tok_idx) {
const auto id = ort_extensions::narrow<extTokenId_t>(*(p_ids + tok_idx));
std::string decoded_token;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, decoded_token, f_special_last)
: Id2Token(id, decoded_token, true, f_special_last);

if (!status.IsOk()) {
return status;
}

bool f_special = all_special_ids_.count(id) ? true : false;

if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}

text.append(decoded_token);

if (whitespace_token_ && f_special && tok_idx != seq_len - 1) {
text.push_back(' ');
}
}

if (tok_config_->tokenizer_class_.find("CLIP") == 0 && !text.empty() && text.back() == ' ') {
text.pop_back();
}

decoded_strings.emplace_back(std::move(text));
p_ids += seq_len;
}

output.SetStringOutput(decoded_strings, output_dim);
return {};
}

private:
void CreateByteDecoder(const ort_extensions::BpeModel& /* bpe_model */) {
char32_t index = 256;
for (char32_t i = 0; i < 256; ++i) {
/*
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
*/
if ((/* i >= 0 && */ i < 33) || (i >= 127 && i < 161) || (i == 173)) {
byte_decoder_[index++] = gsl::narrow<unsigned char>(i);
} else {
byte_decoder_[i] = gsl::narrow<unsigned char>(i);
}
}
}

private:
extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
};
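A hedged sketch of how the stateful Id2Token overload might be driven for incremental detokenization; the decoder setup and the generated_ids source are assumed, and the released state is freed by the caller in this sketch (the C API wraps it in a DetokenizerCache instead).

// Illustrative only: decode ids one at a time, carrying BPEDecoderState across calls
// so incomplete UTF-8 sequences are buffered rather than emitted.
ort_extensions::BPEDecoderState* state = nullptr;
std::string text;
for (extTokenId_t id : generated_ids) {   // generated_ids: assumed stream of token ids
  std::string piece;
  OrtxStatus status = decoder.Id2Token(id, piece, &state);
  if (!status.IsOk()) break;
  text += piece;                          // piece may be empty while bytes are pending
}
delete state;                             // ownership was released to the caller above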
@ -19,10 +19,13 @@
#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"
#include "bpe_types.h"

namespace ort_extensions {

class BpeModel {
using json = nlohmann::json;

public:
BpeModel() = default;

@ -49,7 +52,7 @@ class BpeModel {
bool spm_converted) {
nlohmann::json tok_json;
vocab_stream >> tok_json;
vocab_map_ = std::move(tok_json.get<std::unordered_map<std::string, uint32_t>>());
tok_json.get_to(vocab_map_);

auto it = vocab_map_.find(unk_token);
if (it != vocab_map_.end()) {

@ -123,6 +126,75 @@ class BpeModel {
return {};
}

OrtxStatus Load(const json& bpe_model,
const char* /* special_tokens */,
bool spm_converted) {
const json& vocab_json = bpe_model["vocab"];
const json& merges_json = bpe_model["merges"];
vocab_json.get_to(vocab_map_);
auto it = bpe_model.find("unk_token");
if (it != bpe_model.end() && it->is_string()) {
auto ukt = it->get<std::string>();
auto it_word = vocab_map_.find(ukt);
if (it_word != vocab_map_.end()) {
unk_id_ = it_word->second;
}
}

it = bpe_model.find("end_of_word_suffix");
if (it != bpe_model.end() && it->is_string()) {
end_of_word_suffix_ = it->get<std::string>();
}

if (spm_converted) {
UpdateSpmByteToken(vocab_map_);
} else {
CreateByteEncoder();
}

uint32_t index = 0;
auto merge_item = merges_json.begin();
while (merge_item != merges_json.end()) {
std::string line = merge_item.value();
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
return {
kOrtxErrorCorruptData,
"Cannot know how to parse line: " + line,
};
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;

merge_item++;
}

id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
continue; // safe purpose.
}
if (i > id2token_map_.size()) {
id2token_map_.resize(static_cast<size_t>(i) + 1);
}
id2token_map_[i] = t;
}

return {};
}

OrtxStatus LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);

@ -149,6 +221,18 @@ class BpeModel {
return {};
}

OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}

return {};
}

std::vector<std::string> BuildDecoder() const {
return id2token_map_;
}

// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
// split by added tokens

@ -223,15 +307,19 @@ class BpeModel {
return byte_encoder_;
}

uint32_t GetTokenId(const std::string& key) {
uint32_t GetTokenId(const std::string& key) const {
auto it = vocab_map_.find(key);
if (it != end(vocab_map_)) {
if (it != vocab_map_.end()) {
return it->second;
} else {
return unk_id_;
}
}

const std::string& GetEndOfWordSuffix() const {
return end_of_word_suffix_;
}

private:
struct BpeNode {
uint32_t id;

@ -260,6 +348,7 @@ class BpeModel {
}

private:
std::string end_of_word_suffix_;
std::map<uint64_t, BpeNode> bpe_rank_;

uint32_t byte_encoder_[256] = {};
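For reference, the "model" node consumed by this json overload of Load looks roughly like the sketch below; the keys are the ones read above, and each merges entry is a single "w1 w2" string that the loader splits at the first space. Tokens and ids are illustrative only.

"model": {
  "unk_token": "<unk>",
  "end_of_word_suffix": "",
  "vocab": { "h": 51, "e": 42, "he": 301, "llo": 512, "hello": 2600 },
  "merges": [ "h e", "l lo", "he llo" ]
}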
@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>

namespace ort_extensions {
class BpeModel;

namespace bpe {

struct AddedToken final {
uint32_t id_{};
std::string token_type_;
std::string content_;
bool lstrip_{};
bool normalized_{};
bool rstrip_{};
bool single_word_{};
bool special_{};
};

class TokenJsonConfig; // forward declaration

} // namespace bpe
} // namespace ort_extensions
@ -0,0 +1,18 @@
LIBRARY "ortextensions.dll"
EXPORTS
RegisterCustomOps @1
AddExternalCustomOp @2
GetActiveOrtAPIVersion @3
OrtxGetAPIVersion @4
OrtxGetLastErrorMessage @5
OrtxCreate @6
OrtxDispose @7
OrtxCreateTokenizer @8
OrtxTokenize @9
OrtxDetokenize @10
OrtxDetokenize1D @11
OrtxDetokenizeCached @12
OrtxStringArrayGetBatch @13
OrtxStringArrayGetItem @14
OrtxTokenId2DArrayGetBatch @15
OrtxTokenId2DArrayGetItem @16
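Paired with the exports above, a minimal consumer of the new C ABI might look like the following sketch; only signatures that appear in this change are used, the header name and tokenizer directory are placeholders, and the id-retrieval calls are left as comments because their exact signatures are not shown here.

#include <cstdio>
#include "ortx_tokenizer.h"  // assumed name of the public header declaring the Ortx* API

int main() {
  OrtxTokenizer* tokenizer = nullptr;
  if (OrtxCreateTokenizer(&tokenizer, "/path/to/tokenizer_dir") != kOrtxOK) {
    std::printf("create failed: %s\n", OrtxGetLastErrorMessage());
    return 1;
  }

  const char* input[] = {"This is a test."};
  OrtxTokenId2DArray* token_ids = nullptr;
  if (OrtxTokenize(tokenizer, input, 1, &token_ids) == kOrtxOK) {
    // OrtxTokenId2DArrayGetBatch/GetItem (exported above) would be used here to read
    // the ids back, and OrtxDetokenize to round-trip them to text.
    OrtxDispose(reinterpret_cast<OrtxObject**>(&token_ids));  // cast shown defensively; the handle types may already alias OrtxObject
  }

  OrtxDispose(reinterpret_cast<OrtxObject**>(&tokenizer));
  return 0;
}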
@ -0,0 +1,385 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstdarg>
#include <filesystem>
#include <algorithm>

#include "tokenizer_impl.h"

namespace ort_extensions {

class TokenId2DArray : public OrtxObjectImpl {
 public:
  TokenId2DArray() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenId2DArray) {}
  ~TokenId2DArray() override = default;

  void SetTokenIds(std::vector<std::vector<extTokenId_t>>&& token_ids) {
    token_ids_ = token_ids;
  }

  [[nodiscard]] const std::vector<std::vector<extTokenId_t>>& token_ids() const {
    return token_ids_;
  }

 private:
  std::vector<std::vector<extTokenId_t>> token_ids_;
};

class StringArray : public OrtxObjectImpl {
 public:
  StringArray() : OrtxObjectImpl(extObjectKind_t::kOrtxKindStringArray) {}
  ~StringArray() override = default;

  void SetStrings(std::vector<std::string>&& strings) {
    strings_ = strings;
  }

  [[nodiscard]] const std::vector<std::string>& strings() const {
    return strings_;
  }

 private:
  std::vector<std::string> strings_;
};

class DetokenizerCache : public OrtxObjectImpl {
 public:
  DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
  ~DetokenizerCache() override = default;

  std::unique_ptr<BPEDecoderState> decoder_state_{};
  std::string last_text_{};  // last detokenized text
};

}  // namespace ort_extensions

using namespace ort_extensions;

thread_local std::string last_error_message;

OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const {
  if (ext_kind_ != static_cast<int>(kind)) {
    return {extError_t::kOrtxErrorInvalidArgument,
            "Object is not an instance of the requested type"};
  }
  return {};
}

struct ReturnableStatus {
  ReturnableStatus() = default;
  ReturnableStatus(OrtxStatus&& status) : status_(status) {}
  ~ReturnableStatus() {
    if (!status_.IsOk()) {
      last_error_message = status_.Message();
    }
  }
  ReturnableStatus& operator=(OrtxStatus&& status) {
    status_ = status;
    return *this;
  }
  bool IsOk() const { return status_.IsOk(); }
  extError_t Code() const { return status_.Code(); }

 private:
  OrtxStatus status_;
};

int ORTX_API_CALL OrtxGetAPIVersion() {
  return API_VERSION;
}

const char* OrtxGetLastErrorMessage() {
  return last_error_message.c_str();
}

extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, ...) {
  if (object == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  if (kind == extObjectKind_t::kOrtxKindUnknown) {
    return kOrtxErrorInvalidArgument;
  }

  va_list args;
  va_start(args, object);

  if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) {
    *object = std::make_unique<DetokenizerCache>().release();
  } else if (kind == extObjectKind_t::kOrtxKindTokenizer) {
    return OrtxCreateTokenizer(static_cast<OrtxTokenizer**>(object), va_arg(args, const char*));
  }

  va_end(args);
  return extError_t();
}

extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer,
                                             const char* tokenizer_path) {
  // test if the tokenizer_path is a valid directory
  if (tokenizer_path == nullptr) {
    last_error_message = "The tokenizer data directory is null";
    return kOrtxErrorInvalidArgument;
  }

  if (!std::filesystem::is_directory(tokenizer_path)) {
    last_error_message = std::string("Cannot find the directory of ") + tokenizer_path;
    return kOrtxErrorInvalidArgument;
  }

  ReturnableStatus status;
  // auto ptr = ort_extensions::CreateTokenizer(tokenizer_path, "", &status);
  auto ptr = std::make_unique<ort_extensions::TokenizerImpl>();
  status = ptr->Load(tokenizer_path);
  if (status.IsOk()) {
    *tokenizer = static_cast<OrtxTokenizer*>(ptr.release());
    return extError_t();
  }

  return status.Code();
}

template <typename T>
void Dispose(T* object) {
  auto token_ptr = static_cast<T*>(object);
  std::unique_ptr<T> ptr(token_ptr);
  ptr.reset();
}

extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) {
  if (object == nullptr || *object == nullptr) {
    return kOrtxErrorInvalidArgument;
  }

  auto Ortx_object = static_cast<OrtxObjectImpl*>(*object);
  if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindUnknown) {
    return kOrtxErrorInvalidArgument;
  }

  if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
    Dispose(static_cast<ort_extensions::StringArray*>(*object));
  } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) {
    Dispose(static_cast<ort_extensions::TokenId2DArray*>(*object));
  } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) {
    Dispose(static_cast<ort_extensions::DetokenizerCache*>(*object));
  } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenizer) {
    Dispose(static_cast<ort_extensions::TokenizerImpl*>(*object));
  }

  *object = nullptr;
  return extError_t();
}

extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
  return OrtxDispose(&object);
}

extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
                                      const char* input[], size_t batch_size, OrtxTokenId2DArray** output) {
  if (tokenizer == nullptr || input == nullptr || output == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status =
      token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
  if (!status.IsOk()) {
    return status.Code();
  }

  std::vector<std::vector<extTokenId_t>> t_ids;
  std::vector<std::string_view> input_view;
  std::transform(input, input + batch_size, std::back_inserter(input_view),
                 [](const char* str) { return std::string_view(str); });

  status = token_ptr->Tokenize(input_view, t_ids);
  if (!status.IsOk()) {
    return status.Code();
  }

  auto result = std::make_unique<ort_extensions::TokenId2DArray>().release();
  result->SetTokenIds(std::move(t_ids));
  *output = static_cast<OrtxTokenId2DArray*>(result);

  return extError_t();
}

extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
                                        const OrtxTokenId2DArray* input, OrtxStringArray** output) {
  if (tokenizer == nullptr || input == nullptr || output == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (!status.IsOk()) {
    return status.Code();
  }

  auto input_2d = static_cast<const TokenId2DArray*>(input);
  status = input_2d->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray);
  if (!status.IsOk()) {
    return status.Code();
  }

  std::vector<span<extTokenId_t const>> t_ids;
  std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(),
                 std::back_inserter(t_ids),
                 [](const std::vector<extTokenId_t>& vec) {
                   return span<extTokenId_t const>(vec.data(), vec.size());
                 });

  std::vector<std::string> output_text;
  status = token_ptr->Detokenize(t_ids, output_text);
  if (!status.IsOk()) {
    return status.Code();
  }

  auto result = std::make_unique<ort_extensions::StringArray>().release();
  result->SetStrings(std::move(output_text));
  *output = static_cast<OrtxStringArray*>(result);

  return extError_t();
}

extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer,
                                          const extTokenId_t* input,
                                          size_t len,
                                          OrtxStringArray** output) {
  if (tokenizer == nullptr || input == nullptr || output == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (!status.IsOk()) {
    return status.Code();
  }

  std::vector<span<extTokenId_t const>> t_ids = {{input, len}};
  std::vector<std::string> output_text;
  status = token_ptr->Detokenize(t_ids, output_text);
  if (!status.IsOk()) {
    return status.Code();
  }

  auto result = std::make_unique<ort_extensions::StringArray>().release();
  result->SetStrings(std::move(output_text));
  *output = static_cast<OrtxStringArray*>(result);

  return extError_t();
}

extError_t ORTX_API_CALL OrtxStringArrayGetBatch(const OrtxStringArray* string_array, size_t* length) {
  if (string_array == nullptr || length == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const StringArray*>(string_array);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindStringArray));
  if (!status.IsOk()) {
    return status.Code();
  }

  *length = token_ptr->strings().size();

  return extError_t();
}

extError_t ORTX_API_CALL OrtxStringArrayGetItem(const OrtxStringArray* string_array, size_t index, const char** item) {
  if (string_array == nullptr || item == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const StringArray*>(string_array);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindStringArray));
  if (!status.IsOk()) {
    return status.Code();
  }

  if (index >= token_ptr->strings().size()) {
    last_error_message = "the index is out of range";
    return kOrtxErrorInvalidArgument;
  }

  *item = token_ptr->strings()[index].c_str();

  return extError_t();
}

extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* token_id_2d_array, size_t* length) {
  if (token_id_2d_array == nullptr || length == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_2d_ptr = static_cast<const TokenId2DArray*>(token_id_2d_array);
  ReturnableStatus status(token_2d_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray));
  if (!status.IsOk()) {
    return status.Code();
  }

  *length = token_2d_ptr->token_ids().size();

  return extError_t();
}

extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array,
                                                   size_t index, const extTokenId_t** item, size_t* length) {
  if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenId2DArray*>(token_id_2d_array);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenId2DArray));
  if (!status.IsOk()) {
    return status.Code();
  }

  if (index >= token_ptr->token_ids().size()) {
    last_error_message = "the index is out of range";
    return kOrtxErrorInvalidArgument;
  }

  *item = token_ptr->token_ids()[index].data();
  *length = token_ptr->token_ids()[index].size();

  return extError_t();
}

extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer,
                                OrtxDetokenizerCache* cache,
                                extTokenId_t next_id, const char** text_out) {
  if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) {
    last_error_message = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (!status.IsOk()) {
    return status.Code();
  }

  auto cache_ptr = static_cast<DetokenizerCache*>(cache);
  status = ReturnableStatus(cache_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindDetokenizerCache));
  if (!status.IsOk()) {
    return status.Code();
  }

  cache_ptr->last_text_.clear();
  status = ReturnableStatus(token_ptr->Id2Token(next_id, cache_ptr->last_text_, cache_ptr->decoder_state_));
  if (status.IsOk()) {
    *text_out = cache_ptr->last_text_.c_str();
  }

  return status.Code();
}
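Note the error-reporting convention above: every entry point returns an extError_t, and ReturnableStatus copies the failure text into the thread-local last_error_message, so details are retrieved after the fact with OrtxGetLastErrorMessage. A small sketch of a caller following that convention; the directory path and helper name are placeholders:

// Sketch only: check the return code first, then query the thread-local
// error text that ReturnableStatus recorded on failure.
#include <iostream>
#include "ortx_tokenizer.h"

bool TryLoadTokenizer(const char* dir, OrtxTokenizer** tok) {
  extError_t err = OrtxCreateTokenizer(tok, dir);
  if (err != kOrtxOK) {
    std::cerr << "OrtxCreateTokenizer failed (" << err << "): "
              << OrtxGetLastErrorMessage() << std::endl;
    return false;
  }
  return true;
}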
@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "bpe_kernels.h"
#include "bpe_tokenizer.hpp"
#include "bpe_decoder.hpp"

#include "tokenizer_impl.h"

using namespace ort_extensions;

class SimpleAllocator : public ortc::IAllocator {
 public:
  void* Alloc(size_t size) override {
    return malloc(size);
  }

  void Free(void* p) override {
    if (p) {
      free(p);
    }
  }
};

static SimpleAllocator g_allocator;

TokenizerImpl::TokenizerImpl() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {}
TokenizerImpl::~TokenizerImpl() {}

OrtxStatus TokenizerImpl::Load(const std::string& dir) {
  tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
  auto status = tok_config_->Load(dir);
  if (!status.IsOk()) {
    return status;
  }

  tokenizer_ = std::make_unique<JsonFastTokenizer>();
  // load the tokenizer from a config
  status = tokenizer_->Load(*tok_config_);
  if (status.IsOk()) {
    detokenizer_ = std::make_unique<BpeStreamingDecoder>();
    status = detokenizer_->Load(tok_config_, *tokenizer_);
  }

  return status;
}

OrtxStatus TokenizerImpl::BatchEncode(
    const std::vector<std::string_view>& input,
    std::vector<std::vector<extTokenId_t>>& t_ids) const {
  for (const auto& s : input) {
    ortc::Tensor<int64_t> ts_output(&g_allocator);
    ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
    auto status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);

    if (!status.IsOk()) {
      return status;
    }

    std::vector<extTokenId_t> ids(ts_output.NumberOfElement());
    std::transform(ts_output.Data(), ts_output.Data() + ts_output.NumberOfElement(), ids.begin(),
                   [](int64_t v) { return static_cast<extTokenId_t>(v); });
    t_ids.emplace_back(std::move(ids));
  }

  return {};
}

OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids,
                                      std::vector<std::string>& t_text) const {
  for (const auto& s : t_ids) {
    std::vector<int64_t> ids(s.size());
    std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
    ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
    ortc::Tensor<std::string> ts_output;
    OrtxStatus status = detokenizer_->Compute(ts_input, ts_output);
    if (!status.IsOk()) {
      return status;
    }
    t_text.emplace_back(ts_output.AsScalar());
  }
  return {};
}

OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
  return detokenizer_->Id2Token(id, token, state);
}
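For reference, a condensed sketch of driving TokenizerImpl directly for a batch encode/decode round trip. The "data/llama2" directory stands in for whatever tokenizer.json bundle is available; names here are illustrative, not part of the commit:

// Sketch: encode a batch, then decode it back through the same object.
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include "tokenizer_impl.h"

int RoundTrip() {
  ort_extensions::TokenizerImpl tokenizer;
  if (!tokenizer.Load("data/llama2").IsOk()) return 1;  // placeholder directory

  std::vector<std::string_view> input = {"This is a test"};
  std::vector<std::vector<extTokenId_t>> ids;
  if (!tokenizer.Tokenize(input, ids).IsOk()) return 1;

  std::vector<ort_extensions::span<extTokenId_t const>> spans = {ids[0]};
  std::vector<std::string> text;
  if (!tokenizer.Detokenize(spans, text).IsOk()) return 1;

  std::cout << text[0] << std::endl;  // expected to echo the input
  return 0;
}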
@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ortx_tokenizer.h"
#include "bpe_kernels.h"
#include "bpe_json.hpp"
#include "bpe_streaming.hpp"

namespace ort_extensions {
class OrtxObjectImpl : public OrtxObject {
 public:
  explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() {
    ext_kind_ = static_cast<int>(kind);
  }
  virtual ~OrtxObjectImpl() = default;

  [[nodiscard]] OrtxStatus IsInstanceOf(extObjectKind_t kind) const;
  [[nodiscard]] extObjectKind_t ortx_kind() const {
    if (ext_kind_ < static_cast<int>(extObjectKind_t::kOrtxKindBegin) ||
        ext_kind_ >= static_cast<int>(extObjectKind_t::kOrtxKindEnd)) {
      return extObjectKind_t::kOrtxKindUnknown;
    }
    return static_cast<extObjectKind_t>(ext_kind_);
  }
};

template <typename T>
class span {
 public:
  using value_type = std::remove_cv_t<T>;

  span(T* d, size_t s) : data_(d), size_(s) {}
  span(std::vector<value_type>& v) {
    data_ = v.data();
    size_ = v.size();
  }

  T* data() const { return data_; }
  [[nodiscard]] size_t size() const { return size_; }
  T* begin() const { return data_; }
  T* end() const { return data_ + size_; }

 private:
  T* data_;
  size_t size_;
};


class TokenizerImpl : public OrtxObjectImpl {
 public:
  TokenizerImpl();
  virtual ~TokenizerImpl();

 public:
  OrtxStatus Load(const std::string& dir);

  OrtxStatus Tokenize(const std::vector<std::string_view>& input,
                      std::vector<std::vector<extTokenId_t>>& t_ids) const {
    return BatchEncode(input, t_ids);
  }

  OrtxStatus Detokenize(const std::vector<span<extTokenId_t const>>& t_ids,
                        std::vector<std::string>& t_text) const {
    return BatchDecode(t_ids, t_text);
  }

  OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
    BPEDecoderState* state_ptr = cache.get();
    OrtxStatus status = Id2Token(id, token, &state_ptr);
    if (status.IsOk()) {
      if (state_ptr != cache.get()) {
        cache.reset(state_ptr);
      }
    }

    return status;
  }

  OrtxStatus BatchEncode(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const;

  OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;

  OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;

 private:
  std::string tokenizer_dir_;
  std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
  std::unique_ptr<JsonFastTokenizer> tokenizer_;
  std::unique_ptr<BpeStreamingDecoder> detokenizer_;
};

}  // namespace ort_extensions
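The unique_ptr-based Id2Token overload above is the hook for streaming decode: the cache carries BPE decoder state between calls so byte-level merges across tokens come out right. A hedged sketch of incremental detokenization built on it; token ids and function names are illustrative:

// Sketch: decode one token at a time, appending each emitted piece.
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "tokenizer_impl.h"

void StreamDecode(const ort_extensions::TokenizerImpl& tokenizer,
                  const std::vector<extTokenId_t>& ids) {
  std::unique_ptr<ort_extensions::BPEDecoderState> cache;  // reused across calls
  std::string text;
  for (extTokenId_t id : ids) {
    std::string piece;
    if (!tokenizer.Id2Token(id, piece, cache).IsOk()) break;
    text.append(piece);
  }
  std::cout << text << std::endl;
}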
(File diff not shown because it is too large.)
@ -0,0 +1,34 @@
{
  "unk_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "bos_token": {
    "content": "<|startoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "pad_token": "<|endoftext|>",
  "add_prefix_space": false,
  "errors": "replace",
  "do_lower_case": true,
  "name_or_path": "openai/clip-vit-base-patch32",
  "model_max_length": 77,
  "special_tokens_map_file": "/home/suraj/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
  "tokenizer_class": "CLIPTokenizer"
}
@ -0,0 +1,46 @@
{
  "additional_special_tokens": [
    {
      "content": "<start_of_turn>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false
    },
    {
      "content": "<end_of_turn>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false
    }
  ],
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
(File diff not shown because it is too large.)
@ -0,0 +1,70 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<eos>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<bos>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "106": {
      "content": "<start_of_turn>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "107": {
      "content": "<end_of_turn>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [
    "<start_of_turn>",
    "<end_of_turn>"
  ],
  "bos_token": "<bos>",
  "chat_template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<eos>",
  "legacy": null,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "GemmaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
(File diff not shown because it is too large.)
@ -0,0 +1,33 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
(File diff not shown because it is too large.)
@ -0,0 +1,323 @@
|
|||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"50256": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"50257": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50258": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50259": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50260": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50261": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50262": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50263": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50264": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50265": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50266": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50267": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50268": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50269": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50270": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50271": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50272": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50273": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50274": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50275": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50276": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50277": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50278": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50279": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50280": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50281": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50282": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50283": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50284": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50285": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50286": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50287": {
|
||||
"content": "\t\t\t\t\t\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50288": {
|
||||
"content": "\t\t\t\t\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50289": {
|
||||
"content": "\t\t\t\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50290": {
|
||||
"content": "\t\t\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50291": {
|
||||
"content": "\t\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50292": {
|
||||
"content": "\t\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50293": {
|
||||
"content": "\t\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50294": {
|
||||
"content": "\t\t",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"bos_token": "<|endoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"model_max_length": 2048,
|
||||
"tokenizer_class": "CodeGenTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
(File diff not shown because it is too large.)
(File diff not shown because it is too large.)
@ -12,7 +12,11 @@ void convert_test(const char* const_str) {
  std::string string(const_str);
  const std::string const_string(const_str);

#if _WIN32
  auto str = std::shared_ptr<char>(_strdup(const_str));
#else
  auto str = std::shared_ptr<char>(strdup(const_str));
#endif
  ustring char_construct(str.get());
  EXPECT_EQ(const_string, std::string(char_construct));

@ -89,7 +89,7 @@ class TestMathOpString(unittest.TestCase):
        inv_mat = np.linalg.inv(mat)
        ort_inv = OrtPyFunction.from_customop('Inverse')
        act_mat = ort_inv(mat)
        self.assertTrue(np.allclose(inv_mat, act_mat, rtol=1.e-3))
        np.testing.assert_allclose(inv_mat, act_mat, rtol=1.e-2, atol=1.e-3)


if __name__ == "__main__":
@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include <locale>
#include "gtest/gtest.h"
#include "ocos.h"

#include "ortx_tokenizer.h"
#include "bpe_kernels.h"

TEST(bbpe_tokenizer, test_encoder) {
  EXPECT_EQ(0, ORT_OK);
}
@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/* C-only file, to verify the header file C compatibility */

#include <stdio.h>
#include <string.h>
#include "c_only_test.h"

#if defined(_WIN32)
#define strdup _strdup
#endif


extError_t tokenize_text(OrtxTokenizer* tokenizer, const char* text, char** decoded_text) {
  OrtxTokenId2DArray* tok_2d_output = NULL;
  const char* tok_input[] = {text};
  extError_t err = OrtxTokenize(tokenizer, tok_input, 1, &tok_2d_output);
  if (err != kOrtxOK) {
    return err;
  }

  size_t length = 0;
  const extTokenId_t* token_ids = NULL;
  OrtxTokenId2DArrayGetItem(tok_2d_output, 0, &token_ids, &length);

  OrtxStringArray* detok_output = NULL;
  err = OrtxDetokenize1D(tokenizer, token_ids, length, &detok_output);
  if (err != kOrtxOK) {
    ORTX_DISPOSE(tok_2d_output);
    return err;
  }

  const char* decoded_str = NULL;
  OrtxStringArrayGetItem(detok_output, 0, &decoded_str);
  *decoded_text = strdup(decoded_str);

  ORTX_DISPOSE(tok_2d_output);
  ORTX_DISPOSE(detok_output);
  return kOrtxOK;
}
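tokenize_text above is only a test helper, but it fixes the ownership rules: the helper strdup's the decoded string and the caller frees it. A hedged sketch of a caller; the tokenizer directory and function name are placeholders:

// Sketch: drive the C-compatible helper and release what it allocates.
#include <stdio.h>
#include <stdlib.h>
#include "c_only_test.h"

int run_round_trip(void) {
  OrtxTokenizer* tokenizer = NULL;
  if (OrtxCreateTokenizer(&tokenizer, "data/tiktoken") != kOrtxOK) {  // placeholder directory
    return 1;
  }

  char* decoded = NULL;
  extError_t err = tokenize_text(tokenizer, "This is a test", &decoded);
  if (err == kOrtxOK) {
    printf("%s\n", decoded);
    free(decoded);  // ownership passed to the caller by strdup
  }

  OrtxDispose(&tokenizer);
  return err == kOrtxOK ? 0 : 1;
}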
@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ortx_tokenizer.h"


#ifdef __cplusplus
extern "C"
#endif  // __cplusplus
extError_t tokenize_text(OrtxTokenizer* tokenizer, const char* text, char** decoded_text);
@ -0,0 +1,271 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include <locale>
#include "gtest/gtest.h"
#include "ocos.h"

#include "c_only_test.h"
#include "shared/lib/tokenizer_impl.h"

static void DumpTokenIds(const std::vector<std::vector<extTokenId_t>>& token_ids) {
#ifdef _DEBUG
  for (const auto& tokens : token_ids) {
    for (const auto& token : tokens) {
      std::cout << token << " ";
    }

    std::cout << std::endl;
  }

  std::cout << std::endl;
#endif
}

TEST(CApiTest, ApiTest) {
  int ver = OrtxGetAPIVersion();
  EXPECT_GT(ver, 0);
  OrtxTokenizer* tokenizer = NULL;
  extError_t err = OrtxCreateTokenizer(&tokenizer, "data/tiktoken");
  EXPECT_EQ(err, kOrtxOK);

  const char* input = "This is a test";
  char* decoded_text = NULL;
  err = tokenize_text(tokenizer, input, &decoded_text);
  EXPECT_EQ(err, kOrtxOK);
  EXPECT_STREQ(decoded_text, input);
  free(decoded_text);
}

TEST(CApiTest, StreamApiTest) {
  OrtxTokenizer* tokenizer = NULL;
  extError_t err = OrtxCreate(kOrtxKindTokenizer, &tokenizer, "data/llama2");
  EXPECT_EQ(err, kOrtxOK);

  OrtxDetokenizerCache* detok_cache = NULL;
  err = OrtxCreate(kOrtxKindDetokenizerCache, &detok_cache);
  EXPECT_EQ(err, kOrtxOK);

  extTokenId_t token_ids[] = {1, 910, 338, 263, 1243, 322, 278, 1473, 697, 29889, 29871, 35};
  for (size_t i = 0; i < sizeof(token_ids) / sizeof(token_ids[0]); i++) {
    const char* token = NULL;
    err = OrtxDetokenizeCached(tokenizer, detok_cache, token_ids[i], &token);
#ifdef _DEBUG
    std::cout << token;
#endif
    EXPECT_EQ(err, kOrtxOK);
  }

#ifdef _DEBUG
  std::cout << std::endl;
#endif

  OrtxDisposeOnly(detok_cache);
  OrtxDispose(&tokenizer);
}


TEST(OrtxTokenizerTest, ClipTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/clip");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate tokenizer is not null
  EXPECT_NE(tokenizer, nullptr);

  std::vector<std::string_view> input = {"this is a test", "the second one"};

  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), 2);
  EXPECT_EQ(token_ids[0].size(), 6);
  EXPECT_EQ(token_ids[1].size(), 5);

  std::vector<std::string> out_text;
  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
  status = tokenizer->Detokenize(token_ids_span, out_text);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(out_text[0], input[0]);
}

TEST(OrtxTokenizerTest, TicTokenTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/tiktoken");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate tokenizer is not null
  EXPECT_NE(tokenizer, nullptr);

  std::vector<extTokenId_t> EXPECTED_IDS_0 = {128000, 2028, 374, 264, 1296};
  std::vector<std::string_view> input = {"This is a test", "the second one"};

  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), 2);
  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
  EXPECT_EQ(token_ids[1].size(), 4);

  std::vector<std::string> out_text;
  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
  status = tokenizer->Detokenize(token_ids_span, out_text);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(out_text[0], input[0]);
}


TEST(OrtxTokenizerTest, GemmaTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/gemma");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  std::vector<std::string_view> input = {
      "I like walking my cute dog\n and\x17 then",
      "生活的真谛是",
      "\t\t\t\t \n\n61",
      "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
  std::vector<extTokenId_t> EXPECTED_IDS_0 = {2, 235285, 1154, 10350, 970, 9786, 5929, 108, 578, 240, 1492};
  std::vector<extTokenId_t> EXPECTED_IDS_1 = {2, 122182, 235710, 245467, 235427};
  std::vector<extTokenId_t> EXPECTED_IDS_2 = {2, 255971, 235248, 109, 235318, 235274};
  std::vector<extTokenId_t> EXPECTED_IDS_3 = {2, 6750, 1, 235265, 235248, 255969, 235248, 109, 4747, 139, 235335, 139,
                                              216311, 241316, 139, 239880, 235341, 144, 235269, 235248, 235274, 235284,
                                              235304, 235310, 235248, 235274, 235308, 235248, 235308, 235269, 235318, 235274};

  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), input.size());
  DumpTokenIds(token_ids);
  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
  EXPECT_EQ(token_ids[1], EXPECTED_IDS_1);
  EXPECT_EQ(token_ids[2], EXPECTED_IDS_2);
  EXPECT_EQ(token_ids[3], EXPECTED_IDS_3);

  std::vector<std::string> out_text;
  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {
      EXPECTED_IDS_0, EXPECTED_IDS_1, EXPECTED_IDS_2, EXPECTED_IDS_3};
  status = tokenizer->Detokenize(token_ids_span, out_text);
  EXPECT_TRUE(status.IsOk());
  // std::cout << out_text[0] << std::endl;
  // std::cout << out_text[1] << std::endl;
  // std::cout << out_text[2] << std::endl;
  EXPECT_EQ(out_text[0], input[0]);
  EXPECT_EQ(out_text[1], input[1]);
}

static const char* kPromptText = R"(```python
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)''')";

TEST(OrtxTokenizerTest, CodeGenTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/phi-2");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate tokenizer is not null
  EXPECT_NE(tokenizer, nullptr);

  const char* prompt_text = kPromptText;

  std::vector<std::string_view> input = {prompt_text};
  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), 1);

  std::vector<std::string> out_text;
  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0]};
  status = tokenizer->Detokenize(token_ids_span, out_text);
  EXPECT_TRUE(status.IsOk());
  // std::cout << out_text[0] << std::endl;
  EXPECT_EQ(out_text[0], input[0]);
}

TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/phi-2");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate tokenizer is not null
  EXPECT_NE(tokenizer, nullptr);

  const char* prompt_text = kPromptText;

  std::vector<std::string_view> input = {prompt_text};
  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), 1);

  std::string text;
  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
  // token_ids[0].insert(token_ids[0].begin() + 2, 607);  // <0x20>
  token_ids[0] = {921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
  for (const auto& token_id : token_ids[0]) {
    std::string token;
    status = tokenizer->Id2Token(token_id, token, decoder_cache);
    EXPECT_TRUE(status.IsOk());
    // std::cout << token;
    text.append(token);
  }

  // EXPECT_EQ(text, input[0]);
}

TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
  // test the llama2 tokenizer with BPE class, instead of sentencepiece wrapper.
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/llama2");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate tokenizer is not null
  EXPECT_TRUE(tokenizer != nullptr);

  std::vector<std::string_view> input = {"This is a test and the second one. "};
  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  // Add an extra byte token for decoding tests
  token_ids[0].push_back(35);  // <0x20>
  DumpTokenIds(token_ids);

  std::string text;
  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
  // std::cout << "\"";
  for (const auto& token_id : token_ids[0]) {
    std::string token;
    auto status = tokenizer->Id2Token(token_id, token, decoder_cache);
    EXPECT_TRUE(status.IsOk());
    // std::cout << token;
    text.append(token);
  }

  // std::cout << "\"" << std::endl;
  EXPECT_EQ(std::string(text), std::string(input[0]) + " ");  // from the extra byte token
}
@ -19,7 +19,7 @@ void FixCurrentDir(const std::string& init_path = "") {
  auto cur_dir = std::filesystem::current_path();
  if (!init_path.empty()) {
    std::filesystem::path init_dir = init_path;
    cur_dir = init_dir.parent_path();
    cur_dir = std::filesystem::absolute(init_dir).parent_path();
  }

  do {