diff --git a/CMakeLists.txt b/CMakeLists.txt
index db5d5846..fe4551c9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -362,16 +362,13 @@ if(OCOS_ENABLE_BLINGFIRE)
 endif()
 
 if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
-  if(NOT TARGET nlohmann_json)
-    set(JSON_BuildTests OFF CACHE INTERNAL "")
-    message(STATUS "Fetch json")
-    include(json)
-  endif()
+  message(STATUS "Fetch json")
+  include(json)
 endif()
 
 if(_HAS_TOKENIZER)
   message(STATUS "Tokenizer needed.")
-  file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
+  file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*" "operators/tokenizer/*.hpp")
   list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
 endif()
 
@@ -412,8 +409,7 @@ target_include_directories(ocos_operators PUBLIC
   ${ONNXRUNTIME_INCLUDE_DIR}
   ${PROJECT_SOURCE_DIR}/includes
   ${PROJECT_SOURCE_DIR}/base
-  ${PROJECT_SOURCE_DIR}/operators
-  ${PROJECT_SOURCE_DIR}/operators/tokenizer)
+  ${PROJECT_SOURCE_DIR}/operators)
 
 set(ocos_libraries)
 set(OCOS_COMPILE_DEFINITIONS)
@@ -424,6 +420,8 @@ endif()
 
 if(_HAS_TOKENIZER)
   list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
+  target_include_directories(ocos_operators PUBLIC
+    ${PROJECT_SOURCE_DIR}/operators/tokenizer)
 endif()
 
 if(OCOS_ENABLE_TF_STRING)
@@ -491,7 +489,7 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
   list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
 endif()
 
-target_include_directories(ocos_operators PRIVATE ${GSL_INCLUDE_DIR})
+target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
 list(APPEND ocos_libraries Microsoft.GSL::GSL)
 
 list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)
diff --git a/base/ocos.cc b/base/ocos.cc
index 1d166506..a709a3e3 100644
--- a/base/ocos.cc
+++ b/base/ocos.cc
@@ -2,26 +2,7 @@
 // Licensed under the MIT License.
 #include <cstring>
 #include "ocos.h"
-
-bool BaseKernel::HasAttribute(const char* name) const noexcept {
-  size_t size;
-  std::string out;
-  // Crashes here.
-  OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
-  auto r = api_.GetErrorCode(status);
-  bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
-  if (has) {
-    api_.ReleaseStatus(status);
-    return has;
-  }
-  const char* error = api_.GetErrorMessage(status);
-  if (strstr(error, "No attribute") == error) {
-    api_.ReleaseStatus(status);
-    return false;
-  }
-  api_.ReleaseStatus(status);
-  return true;
-}
+#include "narrow.h"
 
 OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
   if (status == nullptr) {
@@ -72,6 +53,17 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcep
   return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
 }
 
+template <>
+bool BaseKernel::TryToGetAttribute(const char* name, int& value) const noexcept {
+  int64_t origin_value = 0;
+  if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
+    return false;
+  }
+
+  value = ort_extensions::narrow<int>(origin_value);
+  return true;
+}
+
 template <>
 bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
   int64_t origin_value = 0;
diff --git a/build.bat b/build.bat
index 8918418c..bdd39a1d 100644
--- a/build.bat
+++ b/build.bat
@@ -7,6 +7,9 @@ for /f "tokens=* USEBACKQ" %%i in (
 
 IF NOT DEFINED VSINSTALLDIR GOTO :NOT_FOUND
 
+IF "%1" == "-A" GOTO :VSDEV_CMD
+set GEN_PLATFORM=-A x64
+
 :VSDEV_CMD
 set GENERATOR="Visual Studio 16 2019"
 IF "%VisualStudioVersion:~0,2%" == "16" GOTO :START_BUILD
@@ -15,7 +18,7 @@ set GENERATOR="Visual Studio 17 2022"
 :START_BUILD
 set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
 mkdir .\out\Windows\ 2>NUL
-%cmake_exe% -G %GENERATOR% -A x64 %* -B out\Windows -S .
+%cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S .
 IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
 %cmake_exe% --build out\Windows --config RelWithDebInfo
 IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
diff --git a/cmake/externals/json.cmake b/cmake/externals/json.cmake
index 772ce733..32f08567 100644
--- a/cmake/externals/json.cmake
+++ b/cmake/externals/json.cmake
@@ -2,6 +2,8 @@ FetchContent_Declare(json
   GIT_REPOSITORY https://github.com/nlohmann/json.git
   GIT_TAG v3.10.5)
 
+set(JSON_BuildTests OFF CACHE INTERNAL "")
+
 FetchContent_GetProperties(json)
 if(NOT json_POPULATED)
   FetchContent_Populate(json)
diff --git a/cmake/externals/opencv.cmake b/cmake/externals/opencv.cmake
index fdaa70f8..05f33845 100644
--- a/cmake/externals/opencv.cmake
+++ b/cmake/externals/opencv.cmake
@@ -111,6 +111,13 @@ if(IOS)
   set(CPU_BASELINE DETECT)
 endif()
 
+if (MSVC AND CMAKE_GENERATOR_PLATFORM)
+  string(TOLOWER ${CMAKE_GENERATOR_PLATFORM} _GEN_PLATFORM)
+  if (${_GEN_PLATFORM} MATCHES "arm|arm64")
+    set(OPENCV_SKIP_SYSTEM_PROCESSOR_DETECTION ON)
+  endif()
+endif()
+
 FetchContent_Declare(
   opencv
   GIT_REPOSITORY https://github.com/opencv/opencv.git
@@ -150,4 +157,3 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
     set_target_properties(${p} PROPERTIES FOLDER "externals/opencv")
   endforeach()
 endif()
-
diff --git a/includes/ocos.h b/includes/ocos.h
index 32be5a69..f57234eb 100644
--- a/includes/ocos.h
+++ b/includes/ocos.h
@@ -23,14 +23,12 @@ struct BaseKernel {
   BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
   }
 
-  bool HasAttribute(const char* name) const noexcept;
-
   template <class T>
   bool TryToGetAttribute(const char* name, T& value) const noexcept;
 
   template <class T>
-  T TryToGetAttributeWithDefault(const char* name, T default_value) const noexcept {
-    T& result = default_value;
+  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
+    T result = default_value;
     TryToGetAttribute(name, result);
     return result;
   }
diff --git a/onnxruntime_extensions/_cuops.py b/onnxruntime_extensions/_cuops.py
index c2621acd..5281b38e 100644
--- a/onnxruntime_extensions/_cuops.py
+++ b/onnxruntime_extensions/_cuops.py
@@ -55,6 +55,18 @@ class GPT2Tokenizer(CustomOp):
         ]
 
 
+class BpeDecoder(CustomOp):
+    @classmethod
+    def get_inputs(cls):
+        return [
+            cls.io_def("ids", onnx.TensorProto.INT64, [None])
+        ]
+
+    @classmethod
+    def get_outputs(cls):
+        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]
+
+
 class VectorToString(CustomOp):
 
     @classmethod
@@ -387,3 +399,7 @@
 Opdef.create(_argsort_op, op_type='ArgSort',
              inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
              outputs=[PyCustomOpDef.dt_int64])
+
+
+class CustomOpConverter:
+    pass
diff --git a/onnxruntime_extensions/_ortapi2.py b/onnxruntime_extensions/_ortapi2.py
index 8c55828b..064be66d 100644
--- a/onnxruntime_extensions/_ortapi2.py
+++ b/onnxruntime_extensions/_ortapi2.py
@@ -18,7 +18,9 @@ def get_opset_version_from_ort():
         "1.9": 15,
         "1.10": 15,
         "1.11": 16,
-        "1.12": 17
+        "1.12": 17,
+        "1.13": 17,
+        "1.14": 18,
     }
 
     ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -32,7 +34,8 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
     ) else onnx.helper.make_model
     model = fn_mm(graph, opset_imports=[
         onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
-    model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
+    model.opset_import.extend(
+        [onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
     return model
 
 
@@ -54,14 +57,23 @@ class OrtPyFunction:
         self.default_inputs = {}
 
     def create_from_customop(self, op_type, *args, **kwargs):
-        graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
+        cvt = kwargs.get('cvt', None)
+        if cvt is None:
+            cvt = args[0] if len(args) > 0 and isinstance(
+                args[0], CustomOpConverter) else None
+            args = args[1:] if cvt is not None else args
+        else:
+            del kwargs['cvt']
+
+        new_kwargs = kwargs if cvt is None else cvt(**kwargs)
+        graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
         self._bind(make_onnx_model(graph))
         return self
 
     def add_default_input(self, **kwargs):
         inputs = {
-            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else \
-                np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
+            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
+            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
         }
 
         self.default_inputs.update(inputs)
@@ -87,7 +99,8 @@ class OrtPyFunction:
 
     def _ensure_ort_session(self):
         if self.ort_session is None:
-            sess = _ort.InferenceSession(self.onnx_model.SerializeToString(), self.get_ort_session_options())
+            sess = _ort.InferenceSession(
+                self.onnx_model.SerializeToString(), self.get_ort_session_options())
             self.ort_session = sess
 
         return self.ort_session
@@ -113,7 +126,8 @@ class OrtPyFunction:
                 # an annoying bug is numpy by default is int32, while pytorch is int64.
                 # so cast the input here automatically.
                 feed[i_.name] = \
-                    ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
+                    ts_x.astype(
+                        np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
             idx += 1
 
         # feed.update(kwargs)
@@ -121,7 +135,8 @@ class OrtPyFunction:
 
     def __call__(self, *args, **kwargs):
         self._ensure_ort_session()
-        outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
+        outputs = self.ort_session.run(
+            None, self._argument_map(*args, **kwargs))
         return outputs[0] if len(outputs) == 1 else tuple(outputs)
 
 
diff --git a/onnxruntime_extensions/cvt.py b/onnxruntime_extensions/cvt.py
new file mode 100644
index 00000000..fc2a79cd
--- /dev/null
+++ b/onnxruntime_extensions/cvt.py
@@ -0,0 +1,43 @@
+import json
+from ._cuops import CustomOpConverter
+
+
+class HFTokenizerConverter(CustomOpConverter):
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def bpe_tokenizer(self, **kwargs):
+        hf_gpt2_tokenizer = self.tokenizer
+        attrs = {'vocab': json.dumps(
+            hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
+        sorted_merges = {v_: k_ for k_,
+                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
+        attrs['merges'] = '\n'.join("{} {}".format(
+            *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
+        attrs.update(**kwargs)
+        return attrs
+
+    def bpe_decoder(self, **kwargs):
+        decoder = self.tokenizer.decoder
+        id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
+        # with open("id_vocab.txt", "w", encoding="utf-8") as f:
+        #     f.write(id_vocab)
+        byte_decoder = self.tokenizer.byte_decoder
+        str_byte_decoder = "\n".join(["{}\t{}".format(
+            ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
+        # with open("byte_decoder.txt", "w", encoding="utf-8") as f:
+        #     f.write(str_byte_decoder)
+        all_special_ids = self.tokenizer.all_special_ids
+        added_tokens = self.tokenizer.added_tokens_decoder
+        str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
+        str_added_tokens = "\n".join(
+            ["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
+        kwargs.update({
+            "id_vocab": id_vocab,
+            "byte_decoder": str_byte_decoder,
+            "added_tokens": str_added_tokens,
+            "all_special_ids": str_all_special_ids,
+            "skip_special_tokens": kwargs.get("skip_special_tokens", False)
+        })
+
+        return kwargs
diff --git a/operators/text/op_ragged_tensor.cc b/operators/text/op_ragged_tensor.cc
index c3c1fca3..f0d0e24e 100644
--- a/operators/text/op_ragged_tensor.cc
+++ b/operators/text/op_ragged_tensor.cc
@@ -81,7 +81,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
 
 KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
     : CommonRaggedTensorToDense(api, info) {
-  missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(&info, "missing_value") : -1;
+  missing_value_ = TryToGetAttributeWithDefault<int64_t>("missing_value", -1);
 }
 
 void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
diff --git a/operators/text/re2_strings/string_regex_replace.cc b/operators/text/re2_strings/string_regex_replace.cc
index 0785c793..9a0d3969 100644
--- a/operators/text/re2_strings/string_regex_replace.cc
+++ b/operators/text/re2_strings/string_regex_replace.cc
@@ -10,7 +10,7 @@ KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api,
 
 KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
     : BaseKernel(api, info) {
-  global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
+  global_replace_ = TryToGetAttributeWithDefault<int64_t>("global_replace", 1);
 }
 
 void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
diff --git a/operators/tokenizer/blingfire_sentencebreaker.cc b/operators/tokenizer/blingfire_sentencebreaker.cc
index 6346b563..a3800fa4 100644
--- a/operators/tokenizer/blingfire_sentencebreaker.cc
+++ b/operators/tokenizer/blingfire_sentencebreaker.cc
@@ -24,9 +24,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
 
   model_ = std::shared_ptr<void>(model_ptr, FreeModel);
 
-  if (HasAttribute("max_sentence")) {
-    max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(&info, "max_sentence"));
-  }
+  max_sentence = TryToGetAttributeWithDefault("max_sentence", -1);
 }
 
 void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
diff --git a/operators/tokenizer/bpe_decoder.hpp b/operators/tokenizer/bpe_decoder.hpp
new file mode 100644
index 00000000..14b391fc
--- /dev/null
+++ b/operators/tokenizer/bpe_decoder.hpp
@@ -0,0 +1,193 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "ocos.h"
+#include "ustring.h"
+#include "narrow.h"
+#include <string>
+#include <vector>
+#include <set>
+#include <map>
+#include <unordered_map>
+#include <algorithm>
+#include <sstream>
+
+struct KernelBpeDecoder : public BaseKernel {
+ public:
+  KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
+    std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "id_vocab");
+    if (vocab.empty()) {
+      ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
+    }
+    BuildIdVocab(vocab);
+
+    std::string byte_decoder = ort_.KernelInfoGetAttribute<std::string>(&info, "byte_decoder");
+    if (byte_decoder.empty()) {
+      ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
+    } else {
+      auto um = ParseId2String(byte_decoder);
+      std::transform(um.begin(), um.end(),
+                     std::inserter(byte_decoder_, byte_decoder_.end()),
+                     [](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
+                                                               ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
+    }
+
+    std::string added_tokens = TryToGetAttributeWithDefault<std::string>("added_tokens", "");
+    if (!added_tokens.empty()) {
+      auto um = ParseId2String(added_tokens);
+      added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
+    }
+
+    std::string all_special_ids = TryToGetAttributeWithDefault<std::string>("all_special_ids", "");
+    if (!all_special_ids.empty()) {
+      auto um = ParseId2String(all_special_ids);
+      std::transform(um.begin(), um.end(),
+                     std::inserter(all_special_ids_, all_special_ids_.end()),
+                     [](const auto& p) { return p.first; });
+    }
+
+    en_normalization_ = TryToGetAttributeWithDefault("en_normalization", 0);
+    skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", 0);
+    whitespace_token_ = TryToGetAttributeWithDefault("whitespace_token", 0);
+    bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
+    eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
+    unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
+  }
+
+  // Parses a "<id>\t<string>" per-line attribute payload into an id-to-string map.
+  std::unordered_map<int64_t, std::string> ParseId2String(const std::string& s_attr) {
+    std::unordered_map<int64_t, std::string> result;
+    result.reserve(s_attr.size() / 4);
+    std::stringstream ss(s_attr);
+
+    std::string line;
+    std::string token;
+    while (std::getline(ss, line, '\n')) {
+      size_t pos_end = 0;
+      int64_t v = std::stoll(line, &pos_end);
+      if (pos_end >= line.size() || line[pos_end] != '\t') {
+        token.clear();
+      } else {
+        token = line.substr(pos_end + 1);
+      }
+      result.emplace(v, token);
+    }
+
+    return result;
+  }
+
+  void BuildIdVocab(const std::string& vocab) {
+    arr_vocab_.reserve(vocab.size() / 2);  // a rough estimate.
+
+    std::u32string u_vocab = ustring(vocab);
+    std::u32string_view uv_vocab(u_vocab);
+    size_t last_pos = 0;
+
+    auto ccount = uv_vocab.size();
+    for (size_t n = 0; n < ccount; ++n) {
+      if (uv_vocab[n] == char32_t('\n')) {
+        std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos);
+        arr_vocab_.emplace_back(ustring(s_tok));
+        last_pos = n + 1;
+      } else if (n == ccount - 1) {
+        std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos + 1);
+        arr_vocab_.emplace_back(ustring(s_tok));
+      }
+    }
+
+    arr_vocab_.shrink_to_fit();
+  }
+
+  void Compute(OrtKernelContext* context) {
+    const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
+    const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
+    OrtTensorDimensions ids_dim(ort_, ids);
+
+    if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
+      ORTX_CXX_API_THROW("[BpeDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
+    }
+
+    std::string text;
+    bool f_special_last = false;
+    bool f_special = false;
+    auto count = static_cast<size_t>(ids_dim.Size());
+
+    for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
+      const auto token = *(p_ids + tok_idx);
+      std::string decoded_token;
+      f_special = all_special_ids_.count(token) ? true : false;
+      if (skip_special_tokens_ && f_special) {
+        f_special_last = f_special;
+        continue;
+      }
+
+      if (added_tokens_.count(token)) {
+        const std::string ws = added_tokens_.at(token);
+        decoded_token = (std::string)ws;
+      } else {
+        const auto str = arr_vocab_[token];
+        for (auto wchr : str) {
+          unsigned char uchr = byte_decoder_.at(wchr);
+          decoded_token.push_back(uchr);
+        }
+      }
+
+      if (whitespace_token_ &&
+          f_special && (tok_idx > 0 && !f_special_last)) {
+        text.push_back(' ');
+      }
+
+      text.append(decoded_token);
+
+      if (whitespace_token_ &&
+          f_special && tok_idx != count - 1) {
+        text.push_back(' ');
+      }
+
+      f_special_last = f_special;
+    }
+
+    std::vector<int64_t> output_dim = {1};
+    std::vector<std::string> result = {text};
+    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
+    FillTensorDataString(api_, ort_, context, result, output);
+  }
+
+ private:
+  std::string bos_token_;
+  std::string eos_token_;
+  std::string unk_token_;
+
+  // Since the ORT API doesn't support a boolean type in ONNX node attributes,
+  // all flag attributes here are defined as int64 to be more explicit.
+  int64_t en_normalization_ = 0;
+  int64_t skip_special_tokens_ = 0;
+  int64_t whitespace_token_ = 0;
+  std::vector<ustring> arr_vocab_;
+  std::map<char32_t, unsigned char> byte_decoder_;
+  std::map<int64_t, std::string> added_tokens_;
+  std::set<int64_t> all_special_ids_;
+};
+
+struct CustomOpBpeDecoder : OrtW::CustomOpBase<CustomOpBpeDecoder, KernelBpeDecoder> {
+  const char* GetName() const {
+    return "BpeDecoder";
+  }
+
+  size_t GetInputTypeCount() const {
+    return 1;
+  }
+
+  ONNXTensorElementDataType GetInputType(size_t index) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  }
+
+  size_t GetOutputTypeCount() const {
+    return 1;
+  }
+
+  ONNXTensorElementDataType GetOutputType(size_t index) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  }
+};
diff --git a/operators/tokenizer/bpetokenizer.hpp b/operators/tokenizer/bpe_tokenizer.hpp
similarity index 100%
rename from operators/tokenizer/bpetokenizer.hpp
rename to operators/tokenizer/bpe_tokenizer.hpp
diff --git a/operators/tokenizer/clip_tokenizer.hpp b/operators/tokenizer/clip_tokenizer.hpp
index d5387919..5bc1741d 100644
--- a/operators/tokenizer/clip_tokenizer.hpp
+++ b/operators/tokenizer/clip_tokenizer.hpp
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #pragma once
-#include "bpetokenizer.hpp"
+#include "bpe_tokenizer.hpp"
 
 struct KernelClipBpeTokenizer : BaseKernel {
   KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
diff --git a/operators/tokenizer/gpt2_tokenizer.hpp b/operators/tokenizer/gpt2_tokenizer.hpp
index 31b2bd2d..68f0d401 100644
--- a/operators/tokenizer/gpt2_tokenizer.hpp
+++ b/operators/tokenizer/gpt2_tokenizer.hpp
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #pragma once
-#include "bpetokenizer.hpp"
+#include "bpe_tokenizer.hpp"
 
 struct KernelBpeTokenizer : BaseKernel {
   KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
diff --git a/operators/tokenizer/roberta_tokenizer.hpp b/operators/tokenizer/roberta_tokenizer.hpp
index b499b686..404a8702 100644
--- a/operators/tokenizer/roberta_tokenizer.hpp
+++ b/operators/tokenizer/roberta_tokenizer.hpp
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #pragma once
-#include "bpetokenizer.hpp"
+#include "bpe_tokenizer.hpp"
 
 struct KernelRobertaBpeTokenizer : BaseKernel {
   KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
diff --git a/operators/tokenizer/sentencepiece_decoder.hpp b/operators/tokenizer/sentencepiece_decoder.hpp
index b6f20514..e10e0c5e 100644
--- a/operators/tokenizer/sentencepiece_decoder.hpp
+++ b/operators/tokenizer/sentencepiece_decoder.hpp
@@ -31,11 +31,10 @@ struct KernelSentencepieceDecoder : BaseKernel {
       ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
     }
 
-    auto count = ids_dim[0];
     std::string decoded_string;
     std::vector<int64_t> output_dim = {1};
     std::vector<int> tids;
-    std::transform(p_ids, p_ids + count,
+    std::transform(p_ids, p_ids + ids_dim.Size(),
                    std::back_inserter(tids),
                    [](auto _id) { return static_cast<int>(_id); });
     auto status = tokenizer_.Decode(tids, &decoded_string);
diff --git a/operators/tokenizer/tokenizers.cc b/operators/tokenizer/tokenizers.cc
index bdd1764a..1f5c7639 100644
--- a/operators/tokenizer/tokenizers.cc
+++ b/operators/tokenizer/tokenizers.cc
@@ -5,6 +5,7 @@
 #include "gpt2_tokenizer.hpp"
 #include "clip_tokenizer.hpp"
 #include "roberta_tokenizer.hpp"
+#include "bpe_decoder.hpp"
 #endif
 
 #ifdef ENABLE_SPM_TOKENIZER
@@ -33,6 +34,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = LoadCustomOpClasses<
   , CustomOpBpeTokenizer
   , CustomOpClipBpeTokenizer
   , CustomOpRobertaBpeTokenizer
+  , CustomOpBpeDecoder
 #endif
 
 #ifdef ENABLE_SPM_TOKENIZER
diff --git a/operators/tokenizer/wordpiece_tokenizer.cc b/operators/tokenizer/wordpiece_tokenizer.cc
index aeaaeb69..661e0a41 100644
--- a/operators/tokenizer/wordpiece_tokenizer.cc
+++ b/operators/tokenizer/wordpiece_tokenizer.cc
@@ -11,9 +11,7 @@ KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtK
   std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
   std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(&info, "suffix_indicator");
   std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unknown_token");
-  max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word")
-                                 ? ort_.KernelInfoGetAttribute<int64_t>(&info, "max_input_chars_per_word")
-                                 : 200;
+  max_input_chars_per_word_ = TryToGetAttributeWithDefault("max_input_chars_per_word", 200);
   suffix_indicator_ = ustring(suffix_indicator);
   unk_token_ = ustring(unk);
 
diff --git a/test/test_bpe_tokenizer.py b/test/test_bpe_tokenizer.py
new file mode 100644
index 00000000..c58faadc
--- /dev/null
+++ b/test/test_bpe_tokenizer.py
@@ -0,0 +1,53 @@
+import unittest
+import numpy as np
+
+from transformers import AutoProcessor
+from onnxruntime_extensions import PyOrtFunction
+from onnxruntime_extensions.cvt import HFTokenizerConverter
+
+
+class TestBpeTokenizer(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.hf_processor = AutoProcessor.from_pretrained(
+            "openai/whisper-tiny.en")
+        cls.tokenizer_cvt = HFTokenizerConverter(cls.hf_processor.tokenizer)
+        return super().setUpClass()
+
+    def test_bpe_tokenizer(self):
+        fn_tokenizer = PyOrtFunction.from_customop(
+            "GPT2Tokenizer",
+            cvt=self.tokenizer_cvt.bpe_tokenizer)
+        test_str = " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
+        test_ids = self.hf_processor.tokenizer.encode(test_str)
+        np.testing.assert_array_equal(fn_tokenizer(np.asarray([test_str]))[0][0], test_ids)
+
+    def test_en_decoder(self):
+        special_tokens = False
+        fn_decoder = PyOrtFunction.from_customop(
+            "BpeDecoder",
+            cvt=self.tokenizer_cvt.bpe_decoder,
+            skip_special_tokens=not special_tokens)
+        test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
+        test_token_ids = self.hf_processor.tokenizer.encode(test_str)
+        expected_str = self.hf_processor.tokenizer.decode(
+            test_token_ids, skip_special_tokens=not special_tokens)
+        self.assertEqual(fn_decoder(np.asarray(test_token_ids))[0], expected_str)
+
+    def test_en_decoder_with_special(self):
+        special_tokens = True
+        fn_decoder = PyOrtFunction.from_customop(
+            "BpeDecoder",
+            cvt=self.tokenizer_cvt.bpe_decoder,
+            skip_special_tokens=not special_tokens)
+        test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
+        test_token_ids = self.hf_processor.tokenizer.encode(test_str)
+        expected_str = self.hf_processor.tokenizer.decode(
+            test_token_ids, skip_special_tokens=not special_tokens)
+        actual_str = fn_decoder(np.asarray(test_token_ids))
+        self.assertEqual(actual_str[0], expected_str)
+
+
+if __name__ == "__main__":
+    unittest.main()
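
---

A quick end-to-end sketch of how the pieces added above fit together: the new cvt= hook on
PyOrtFunction.from_customop turns an HFTokenizerConverter method into custom-op node attributes,
and the new BpeDecoder op reverses GPT2Tokenizer. This is only an illustrative sketch under
stated assumptions, not part of the patch itself: it assumes transformers is installed locally,
and the "gpt2" checkpoint (loaded with use_fast=False so the slow tokenizer exposes encoder,
bpe_ranks, and byte_decoder) is just one example of a byte-level BPE tokenizer.

    import numpy as np
    from transformers import AutoTokenizer
    from onnxruntime_extensions import PyOrtFunction
    from onnxruntime_extensions.cvt import HFTokenizerConverter

    # Wrap an HF byte-level BPE tokenizer so its vocab/merges/byte-decoder
    # tables can be serialized into ONNX node attributes.
    cvt = HFTokenizerConverter(AutoTokenizer.from_pretrained("gpt2", use_fast=False))

    # Two single-op models: encode with the existing GPT2Tokenizer op,
    # decode with the BpeDecoder op introduced by this change.
    encode = PyOrtFunction.from_customop("GPT2Tokenizer", cvt=cvt.bpe_tokenizer)
    decode = PyOrtFunction.from_customop("BpeDecoder", cvt=cvt.bpe_decoder,
                                         skip_special_tokens=True)

    # GPT2Tokenizer returns (input_ids, attention_mask); BpeDecoder takes a 1-D id tensor.
    ids, _mask = encode(np.asarray(["round-trip through ONNX custom ops"]))
    print(decode(np.asarray(ids[0]))[0])  # prints the original text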