Add more tests for pre-processing C APIs (#793)

* initial api for tokenizer

* More fixes and test data refinement

* add a simple wrapper for pre-processing APIs

* fix the test issues

* test if the tokenizer is spm based

* fix the failed test cases

* json pointer does not work
Wenbing Li 2024-08-21 16:48:39 -07:00 committed by GitHub
Parent 85ffb94169
Commit 8f2c35fad0
No key found matching this signature
GPG key ID: B5690EEEBB952194
15 changed files with 282 additions and 225 deletions

View file

@ -147,6 +147,7 @@ class CmdBuildCMakeExt(_build_ext):
self.no_azure = None
self.no_opencv = None
self.cc_debug = None
self.pp_api = None
self.cuda_archs = None
self.ort_pkg_dir = None
@ -210,6 +211,9 @@ class CmdBuildCMakeExt(_build_ext):
'-DOCOS_ENABLE_CV2=OFF',
'-DOCOS_ENABLE_VISION=OFF']
if self.pp_api:
cmake_args += ['-DOCOS_ENABLE_C_API=ON']
if self.no_azure is not None:
azure_flag = "OFF" if self.no_azure == 1 else "ON"
cmake_args += ['-DOCOS_ENABLE_AZURE=' + azure_flag]

View file

@ -11,11 +11,11 @@ class ustring : public std::u32string {
public:
ustring() = default;
explicit ustring(const char* str) { assign(FromUTF8(str)); }
explicit ustring(const char* str) { assign(std::move(FromUTF8(str))); }
explicit ustring(const std::string& str) { assign(FromUTF8(str)); }
explicit ustring(const std::string& str) { assign(std::move(FromUTF8(str))); }
explicit ustring(const std::string_view& str) { assign(FromUTF8(str)); }
explicit ustring(const std::string_view& str) { assign(std::move(FromUTF8(str))); }
explicit ustring(const char32_t* str) : std::u32string(str) {}
@ -76,11 +76,15 @@ class ustring : public std::u32string {
}
}
static bool ValidateUTF8(const std::string& data) {
// Return a non-positive value marking the position of the first invalid UTF-8 character,
// otherwise return the position of the terminating null character, i.e. the length of the string.
static ptrdiff_t ValidateUTF8(const std::string& data) {
const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
const unsigned char* s_begin = s;
const unsigned char* s_end = s + data.size();
if (*s_end != '\0')
return false;
return 0;
while (*s) {
if (*s < 0x80)
@ -89,17 +93,17 @@ class ustring : public std::u32string {
else if ((s[0] & 0xe0) == 0xc0) {
/* 110XXXXx 10xxxxxx */
if (s + 1 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[0] & 0xfe) == 0xc0) /* overlong? */
return false;
return s_begin - s;
else
s += 2;
} else if ((s[0] & 0xf0) == 0xe0) {
/* 1110XXXX 10Xxxxxx 10xxxxxx */
if (s + 2 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[2] & 0xc0) != 0x80 ||
@ -107,27 +111,27 @@ class ustring : public std::u32string {
(s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
(s[0] == 0xef && s[1] == 0xbf &&
(s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
return false;
return s_begin - s;
else
s += 3;
} else if ((s[0] & 0xf8) == 0xf0) {
/* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
if (s + 3 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[2] & 0xc0) != 0x80 ||
(s[3] & 0xc0) != 0x80 ||
(s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
(s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
return false;
return s_begin - s;
else
s += 4;
} else
return false;
return s_begin - s;
}
return true;
return s - s_begin;
}
private:

View file

@ -18,4 +18,13 @@ Most APIs accept raw data inputs such as audio, image compressed binary formats,
**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extraction.cc#L16).
NB: If onnxruntime-extensions is to build as a shared library, which requires the OCOS_ENABLE_AUDIO OCOS_ENABLE_CV2 OCOS_ENABLE_OPENCV_CODECS OCOS_ENABLE_GPT2_TOKENIZER build flags are ON to have a full function of binary. Only onnxruntime-extensions static library can be used for a minimal build with the selected operators, so in that case, the shared library build can be switched off by `-DOCOS_BUILD_SHARED_LIB=OFF`.
**NB:** If onnxruntime-extensions is built as a shared library, the OCOS_ENABLE_AUDIO, OCOS_ENABLE_CV2, OCOS_ENABLE_OPENCV_CODECS and OCOS_ENABLE_GPT2_TOKENIZER build flags must be ON for the binary to be fully functional. Only the onnxruntime-extensions static library supports a minimal build with selected operators; in that case, the shared-library build can be switched off by `-DOCOS_BUILD_SHARED_LIB=OFF`.
There is a simple Python wrapper over these C APIs in [pp_api](../onnxruntime_extensions/pp_api.py), which gives easy access to them from Python code, for example:
```Python
from onnxruntime_extensions.pp_api import Tokenizer
# the name can be the same one used by Huggingface transformers.AutoTokenizer
pp_tok = Tokenizer('google/gemma-2-2b')
print(pp_tok.tokenize("what are you? \n 给 weiss ich, über was los ist \n"))
```
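The wrapper also exposes `detokenize` for the reverse mapping. A minimal round-trip sketch, assuming the package was built with the pp_api option and the same illustrative model files are available:
```Python
from onnxruntime_extensions.pp_api import Tokenizer

# the model name is the same illustrative one used above
pp_tok = Tokenizer('google/gemma-2-2b')
ids = pp_tok.tokenize("what are you? \n 给 weiss ich, über was los ist \n")
# detokenize maps the token ids back to text
print(pp_tok.detokenize(ids))
```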

View file

@ -531,15 +531,6 @@ expect(node, inputs=[inputs],
</details>
### BlingFireSentenceBreaker
TODO
### BpeTokenizer
TODO
## String operators
### StringEqual

View file

@ -16,6 +16,7 @@ The package contains all custom operators and some Python scripts to manipulate
- no-azure: disable AzureOp kernel build in Python package.
- no-opencv: disable operators based on OpenCV in build.
- cc-debug: generate debug info for extensions binaries and disable C/C++ compiler optimization.
- pp_api: enable the pre-processing C API Python wrapper, `from onnxruntime_extensions.pp_api import *` (see the availability check after this list)
- cuda-archs: specify the CUDA architectures (like 70, 85, etc.); multiple values can be separated by semicolons. The default value is taken from the nvidia-smi output for GPU-0
- ort\_pkg\_dir: specify the ONNXRuntime package directory the extension project depends on. This is helpful if you want to use a recent ONNXRuntime feature that has not yet been included in an official build
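A quick way to check whether an installed package was built with this option is simply to import the wrapper; a small sketch, assuming a build with `--ortx-user-option=pp_api`:
```Python
# Probe for the pre-processing C API wrapper; the import raises ImportError
# when the package was built without --ortx-user-option=pp_api.
try:
    from onnxruntime_extensions import pp_api  # noqa: F401
    print("pp_api wrapper is available")
except ImportError as err:
    print("built without the pre-processing C API:", err)
```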

View file

@ -3,11 +3,69 @@
# license information.
###############################################################################
import os
from . import _extensions_pydll as _C
if not hasattr(_C, "create_processor"):
raise ImportError("onnxruntime_extensions is not built with pre-processing API")
if not hasattr(_C, "delete_object"):
raise ImportError(
"onnxruntime_extensions is not built with pre-processing C API"
"To enable it, please build the package with --ortx-user-option=pp_api")
create_processor = _C.create_processor
load_images = _C.load_images
image_pre_process = _C.image_pre_process
tensor_result_get_at = _C.tensor_result_get_at
create_tokenizer = _C.create_tokenizer
batch_tokenize = _C.batch_tokenize
batch_detokenize = _C.batch_detokenize
delete_object = _C.delete_object
class Tokenizer:
def __init__(self, tokenizer_dir):
if os.path.isdir(tokenizer_dir):
self.tokenizer = create_tokenizer(tokenizer_dir)
else:
try:
from transformers.utils import cached_file
resolved_full_file = cached_file(
tokenizer_dir, "tokenizer.json")
resolved_config_file = cached_file(
tokenizer_dir, "tokenizer_config.json")
except ImportError:
raise ValueError(
f"Directory '{tokenizer_dir}' not found and transformers is not available")
if not os.path.exists(resolved_full_file):
raise FileNotFoundError(
f"Downloaded HF file '{resolved_full_file}' cannot be found")
if (os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file)):
raise FileNotFoundError(
f"Downloaded HF files '{resolved_full_file}' and '{resolved_config_file}' are not in the same directory")
tokenizer_dir = os.path.dirname(resolved_full_file)
self.tokenizer = create_tokenizer(tokenizer_dir)
def tokenize(self, text):
return batch_tokenize(self.tokenizer, [text])[0]
def detokenize(self, tokens):
return batch_detokenize(self.tokenizer, [tokens])[0]
def __del__(self):
if delete_object and self.tokenizer:
delete_object(self.tokenizer)
self.tokenizer = None
class ImageProcessor:
def __init__(self, processor_json):
self.processor = create_processor(processor_json)
def pre_process(self, images):
return image_pre_process(self.processor, images)
def __del__(self):
if delete_object and self.processor:
delete_object(self.processor)
self.processor = None
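A sketch of how the `ImageProcessor` wrapper above might be used, with the processor JSON and image path borrowed from the test data elsewhere in this change (illustrative only):
```Python
from onnxruntime_extensions.pp_api import ImageProcessor, load_images, tensor_result_get_at

# paths mirror the test data referenced in this change; substitute your own
processor = ImageProcessor("test/data/processor/phi_3_image.json")
images = load_images(["test/data/processor/passport.png"])
result = processor.pre_process(images)
# the first tensor of the result holds the pixel values as a numpy array
print(tensor_result_get_at(result, 0).shape)
```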

View file

@ -27,11 +27,6 @@ static bool IsBosEosRequired(const std::string& model_name) {
return model_name != kModel_GPT2 && model_name != kModel_CodeGen;
}
static bool IsSpmModel(const std::string& model_name) {
return model_name == kModel_Llama ||
model_name == kModel_Gemma;
}
std::string BpeModelConf::GetSpecialTokens() const {
std::string special_tokens = unk_token_; // unk_token_ is required
auto add_token = [](std::string& sp, const char* tok) {
@ -145,7 +140,7 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
merges_stream,
bpe_conf_.get().unk_token_,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
bpe_conf_.get().spm_model_);
if (!status.IsOk()) {
return (OrtStatusPtr)status;
}
@ -454,7 +449,7 @@ OrtxStatus KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
}
auto tok_fun = &KernelBpeTokenizer::Tokenize;
if (IsSpmModel(ModelName())) {
if (bpe_conf_.get().spm_model_) {
tok_fun = &KernelBpeTokenizer::SpmTokenize;
}
@ -556,7 +551,8 @@ static const auto kSpmConfiguration = BpeModelConf{
"<unk>", // unk_token
"<s>", // bos_token
"</s>", // eos_token
""}; // pad_token
"", // pad_token
true};
SpmTokenizer::SpmTokenizer()
: KernelBpeTokenizer(kSpmConfiguration) {}
@ -718,6 +714,24 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
module_ifs >> tok_json;
} else {
ifs >> tok_json;
// auto decoders_node = tok_json.find("/decoder/decoders"_json_pointer);
auto decoders_node = tok_json.find("decoder");
if (decoders_node != tok_json.end()) {
decoders_node = decoders_node->find("decoders");
}
if (decoders_node->is_array()) {
for(auto step = decoders_node->begin(); step != decoders_node->end(); ++step) {
std::string type = step->value("type", "");
if (type == "Replace") {
std::string target = step->value("/pattern/String"_json_pointer, "");
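// "\xe2\x96\x81" is the UTF-8 encoding of U+2581 ("▁"), the SentencePiece whitespace marker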
if (target == "\xe2\x96\x81") {
json_conf_.spm_model_ = true;
break;
}
}
}
}
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
@ -725,8 +739,8 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
bbpe_tokenizer_ = std::make_unique<BpeModel>();
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
}

View file

@ -20,6 +20,7 @@ struct BpeModelConf {
const char* eos_token_{"<|endoftext|>"};
const char* pad_token_{nullptr};
bool spm_model_{};
std::string GetSpecialTokens() const;
};
@ -108,41 +109,23 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : public KernelBpeTokenizer {
public:
JsonFastTokenizer();
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
bool IsSpmModel() const { return json_conf_.spm_model_; }
bool tiktoken_ = false;
std::string unicode_byte_encoder_[256] = {};
private:
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
private:
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};
class TikTokenizer : KernelBpeTokenizer {
public:
TikTokenizer();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
private:
std::unique_ptr<ort_extensions::BpeModel>bbpe_tokenizer_;
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
std::string unicode_byte_encoder_[256] = {};
};

View file

@ -30,6 +30,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
bos_token_ = tok_config.bos_token_;
eos_token_ = tok_config.eos_token_;
unk_token_ = tok_config.unk_token_;
spm_model_ = encoder.IsSpmModel();
const auto& a_toks = encoder.GetAddedTokens();
for (const auto& tok : a_toks) {
@ -122,10 +123,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
return {};
}
static bool IsSpmTokenizer(const std::string& tok_class) {
return tok_class == "GemmaTokenizer" || tok_class == "LlamaTokenizer";
}
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
auto bpe_state = *state;
std::unique_ptr<BPEDecoderState> bpe_state_ptr;
@ -138,9 +135,9 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
bool f_special = bpe_state->f_special_last; // [Spm]Id2Token needs the last state
bool f_special_last = bpe_state->f_special_last;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, token, f_special)
: Id2Token(id, token, true /* tok_config_.skip_special_tokens_ */, f_special);
auto status = spm_model_
? SpmId2Token(id, token, f_special)
: Id2Token(id, token, true /* tok_config_.skip_special_tokens_ */, f_special);
if (status.IsOk()) {
if (bpe_state_ptr) {
@ -167,7 +164,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
if (utf8_len <= s_utf8.size() - i) {
utf8_all_len += utf8_len;
auto _t = s_utf8.substr(i, utf8_len);
token += ustring::ValidateUTF8(_t) ? _t : "";
token += ustring::ValidateUTF8(_t) > 0 ? _t : "";
}
}
@ -200,9 +197,9 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
for (size_t tok_idx = 0; tok_idx < seq_len; ++tok_idx) {
const auto id = ort_extensions::narrow<extTokenId_t>(*(p_ids + tok_idx));
std::string decoded_token;
auto status = IsSpmTokenizer(tok_config_->tokenizer_class_)
? SpmId2Token(id, decoded_token, f_special_last)
: Id2Token(id, decoded_token, true, f_special_last);
auto status = spm_model_
? SpmId2Token(id, decoded_token, f_special_last)
: Id2Token(id, decoded_token, true, f_special_last);
if (!status.IsOk()) {
return status;
@ -225,6 +222,11 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
text.pop_back();
}
ptrdiff_t z = ustring::ValidateUTF8(text);
if (z <= 0) {
text = text.substr(0, -z);
}
decoded_strings.emplace_back(std::move(text));
p_ids += seq_len;
}
@ -251,7 +253,9 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
}
private:
extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
bool spm_model_{};
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
};

View file

@ -178,7 +178,7 @@ struct KernelTrieDetokenizer {
ids.push_back(ort_extensions::narrow<int>(p_ids[n * ids_dim[1] + i]));
}
auto raw_string = tokenizer->decodeBytes(ids);
if (ustring::ValidateUTF8(raw_string)) {
if (ustring::ValidateUTF8(raw_string) > 0) {
output[n] = raw_string;
} else {
output[n] = "\ufffd"; // bad utf-8 string

View file

@ -9,6 +9,7 @@
#include <thread>
#include "ortx_utils.h"
#include "ortx_tokenizer.h"
#include "ortx_processor.h"
#include "pykernel.h"
@ -68,47 +69,130 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
},
"Preprocess images.");
m.def("tensor_result_get_at", [](std::uintptr_t result_h, size_t index) {
OrtxTensorResult* result = reinterpret_cast<OrtxTensorResult*>(result_h);
OrtxTensor* tensor{};
auto err = OrtxTensorResultGetAt(result, index, &tensor);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to get tensor") + OrtxGetLastErrorMessage());
}
m.def(
"tensor_result_get_at",
[](std::uintptr_t result_h, size_t index) {
OrtxTensorResult* result = reinterpret_cast<OrtxTensorResult*>(result_h);
OrtxTensor* tensor{};
auto err = OrtxTensorResultGetAt(result, index, &tensor);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to get tensor") + OrtxGetLastErrorMessage());
}
extDataType_t tensor_type;
extDataType_t tensor_type;
OrtxGetTensorType(tensor, &tensor_type);
const int64_t* shape{};
size_t num_dims;
const void* data{};
size_t elem_size = 0;
if (tensor_type == extDataType_t::kOrtxInt64 || tensor_type == extDataType_t::kOrtxFloat) {
OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
elem_size = 4;
if (tensor_type == extDataType_t::kOrtxInt64) {
elem_size = 8;
}
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("Failed to get tensor type");
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("unsupported tensor type");
}
OrtxGetTensorType(tensor, &tensor_type);
const int64_t* shape{};
size_t num_dims;
const void* data{};
size_t elem_size = 0;
if (tensor_type == extDataType_t::kOrtxInt64 || tensor_type == extDataType_t::kOrtxFloat) {
OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
elem_size = 4;
if (tensor_type == extDataType_t::kOrtxInt64) {
elem_size = 8;
}
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("Failed to get tensor type");
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("unsupported tensor type");
}
std::vector<std::size_t> npy_dims;
for (auto n = num_dims - num_dims; n < num_dims; ++n) {
npy_dims.push_back(shape[n]);
}
py::array obj{};
std::vector<std::size_t> npy_dims;
for (auto n = num_dims - num_dims; n < num_dims; ++n) {
npy_dims.push_back(shape[n]);
}
py::array obj{};
if (tensor_type == extDataType_t::kOrtxFloat) {
obj = py::array_t<float>(npy_dims);
} else if (tensor_type == extDataType_t::kOrtxInt64) {
obj = py::array_t<int64_t>(npy_dims);
}
if (tensor_type == extDataType_t::kOrtxFloat) {
obj = py::array_t<float>(npy_dims);
} else if (tensor_type == extDataType_t::kOrtxInt64) {
obj = py::array_t<int64_t>(npy_dims);
}
void* out_ptr = obj.mutable_data();
memcpy(out_ptr, data, NumOfElement(npy_dims) * elem_size);
return obj;
}, "Get tensor at index.");
void* out_ptr = obj.mutable_data();
memcpy(out_ptr, data, NumOfElement(npy_dims) * elem_size);
return obj;
},
"Get tensor at index.");
m.def(
"create_tokenizer",
[](std::string tokenizer_def_json) {
OrtxTokenizer* tokenizer = nullptr;
auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
}
return reinterpret_cast<std::uintptr_t>(tokenizer);
},
"Create a tokenizer.");
m.def(
"batch_tokenize",
[](std::uintptr_t h, const std::vector<std::string>& inputs) -> std::vector<std::vector<int64_t>> {
std::vector<std::vector<int64_t>> output;
OrtxTokenizer* tokenizer = reinterpret_cast<OrtxTokenizer*>(h);
OrtxTokenId2DArray* tid_output = nullptr;
std::vector<const char*> cs_inputs;
for (const auto& input : inputs) {
cs_inputs.push_back(input.c_str());
}
auto err = OrtxTokenize(tokenizer, cs_inputs.data(), inputs.size(), &tid_output);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to tokenize") + OrtxGetLastErrorMessage());
}
for (size_t i = 0; i < inputs.size(); ++i) {
const extTokenId_t* t2d{};
size_t length{};
err = OrtxTokenId2DArrayGetItem(tid_output, i, &t2d, &length);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to get token id") + OrtxGetLastErrorMessage());
}
output.push_back(std::vector<int64_t>(t2d, t2d + length));
}
OrtxDisposeOnly(tid_output);
return output;
},
"Batch tokenize.");
m.def(
"batch_detokenize",
[](std::uintptr_t h, const std::vector<std::vector<int64_t>>& inputs) -> std::vector<std::string> {
std::vector<std::string> result;
OrtxTokenizer* tokenizer = reinterpret_cast<OrtxTokenizer*>(h);
OrtxStringArray* output = nullptr;
for (size_t i = 0; i < inputs.size(); ++i) {
std::vector<extTokenId_t> input;
input.reserve(inputs[i].size());
std::transform(inputs[i].begin(), inputs[i].end(), std::back_inserter(input),
[](int64_t v) { return static_cast<extTokenId_t>(v); });
auto err = OrtxDetokenize1D(tokenizer, input.data(), input.size(), &output);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to detokenize") + OrtxGetLastErrorMessage());
}
size_t length;
err = OrtxStringArrayGetBatch(output, &length);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to get batch size") + OrtxGetLastErrorMessage());
}
for (size_t i = 0; i < length; ++i) {
const char* item;
err = OrtxStringArrayGetItem(output, i, &item);
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to get item") + OrtxGetLastErrorMessage());
}
result.push_back(item);
}
OrtxDisposeOnly(output);
}
return result;
},
"Batch detokenize.");
m.def(
"delete_object", [](std::uintptr_t h) { OrtxDisposeOnly(reinterpret_cast<OrtxObject*>(h)); },
"Delete the object created by C API.");
}

View file

@ -1,76 +0,0 @@
import os
import tempfile
from PIL import Image
from transformers import AutoProcessor
from onnxruntime_extensions.pp_api import create_processor, load_images, image_pre_process, tensor_result_get_at
import numpy as np
def regen_image(arr):
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
# Reverse normalization
array = arr * std + mean
# Clip the values to [0, 1] range
array = np.clip(array, 0, 1)
# Convert to [0, 255] range and uint8 type
array = (array * 255).astype(np.uint8)
# Convert NumPy array to PIL Image
image = Image.fromarray(array)
return image
test_image = "test/data/processor/passport.png"
# test_image = "/temp/passport_s.png"
# test_image = "/temp/passport_s2.png"
model_id = "microsoft/Phi-3-vision-128k-instruct"
processor = create_processor("test/data/processor/phi_3_image.json")
images = load_images([test_image])
c_out = image_pre_process(processor, images)
# print(tensor_result_get_at(c_out, 0))
# print(tensor_result_get_at(c_out, 1))
image = Image.open(test_image)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
messages = [
{"role": "user", "content": "<|image_1|>\nWhat is shown in this image?"},
{"role": "assistant", "content": "The chart displays the percentage of respondents who agree with various statements about their preparedness for meetings. It shows five categories: 'Having clear and pre-defined goals for meetings', 'Knowing where to find the information I need for a meeting', 'Understanding my exact role and responsibilities when I'm invited', 'Having tools to manage admin tasks like note-taking or summarization', and 'Having more focus time to sufficiently prepare for meetings'. Each category has an associated bar indicating the level of agreement, measured on a scale from 0% to 100%."},
{"role": "user", "content": "Provide insightful questions to spark discussion."}
]
prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, [image], return_tensors="pt")
# print(inputs["pixel_values"].numpy())
# print(inputs["image_sizes"])
np.testing.assert_allclose(
inputs["image_sizes"].numpy(), tensor_result_get_at(c_out, 1))
# np.testing.assert_allclose(inputs["pixel_values"].numpy(), tensor_result_get_at(c_out, 0), rtol=1e-1)
if os.path.exists("/temp"):
temp_dir = "/temp"
else:
temp_dir = tempfile.mkdtemp()
print(f"Created temp dir: {temp_dir}")
for i in range(17):
expected = inputs["pixel_values"].numpy()[0, i]
actual = tensor_result_get_at(c_out, 0)[0, i]
e_image = regen_image(expected.transpose(1, 2, 0))
a_image = regen_image(actual.transpose(1, 2, 0))
e_image.save(f"{temp_dir}/e_{i}.png")
a_image.save(f"{temp_dir}/a_{i}.png")
try:
np.testing.assert_allclose(inputs["pixel_values"].numpy(
)[0, i], tensor_result_get_at(c_out, 0)[0, i], rtol=1e-2)
except AssertionError as e:
print(str(e))

View file

@ -12,20 +12,6 @@
using namespace ort_extensions;
std::vector<float> ReadArrayFromFile(const std::string& filename) {
std::ifstream inFile(filename, std::ios::binary | std::ios::ate);
if (!inFile) {
throw std::runtime_error("Cannot open file for reading.");
}
std::streamsize fileSize = inFile.tellg();
inFile.seekg(0, std::ios::beg);
std::vector<float> array(fileSize / sizeof(float));
if (!inFile.read(reinterpret_cast<char*>(array.data()), fileSize)) {
throw std::runtime_error("Error reading file.");
}
return array;
}
TEST(ProcessorTest, TestPhi3VImageProcessing) {
auto [input_data, n_data] = ort_extensions::LoadRawImages(
@ -47,20 +33,6 @@ TEST(ProcessorTest, TestPhi3VImageProcessing) {
ASSERT_EQ(image_sizes->Shape(), std::vector<int64_t>({3, 2}));
ASSERT_EQ(num_img_tokens->Shape(), std::vector<int64_t>({3, 1}));
if (std::filesystem::is_directory("data2/processor")) {
// the test data was dumped in this way
// {
// std::ofstream outFile("data2/processor/img_proc_pixel_values.bin", std::ios::binary);
// outFile.write(reinterpret_cast<const char*>(array.data()), array.size() * sizeof(float));
// }
auto expected_output = ReadArrayFromFile("data2/processor/img_proc_pixel_values.bin");
ASSERT_EQ(pixel_values->NumberOfElement(), expected_output.size());
for (size_t i = 0; i < expected_output.size(); i++) {
ASSERT_NEAR(pixel_values->Data()[i], expected_output[i], 1e-3);
}
}
// compare the image sizes
for (size_t i = 0; i < 3; i++) {
ASSERT_EQ(image_sizes->Data()[i * 2], expected_image_size[i * 2]);

View file

@ -13,7 +13,7 @@ TEST(strings, std_regex_test) {
"\U0001f300-\U0001f5ff\U0001f900-\U0001f9ff\U0001fa70-\U0001faff"
"\U0001f680-\U0001f6ff]");
std::string test =u8"abcde😀🔍🦑😁🔍🎉😂🤣";
std::string test = u8"abcde😀🔍🦑😁🔍🎉😂🤣";
auto result = std::regex_replace(test, regex, "");
std::cout << test << std::endl;
std::cout << result << std::endl;

View file

@ -33,12 +33,21 @@ class TestAutoTokenizer(unittest.TestCase):
def test_phi_3_mini(self):
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct", use_fast=True)
text = "what are you? \n 给 weiss ich, über was los ist \n"
ids = tokenizer.encode(text, return_tensors="np")
text = ["what are you? \n 给 weiss ich, über was los ist \n",
"@? \n was los ist \n",
"Qué dijiste? \n über 给 ば was los ist im Mannschaft ц \n",
"明天雷阵雨气温26度。"]
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids[0][1:])
actual_ids, *_ = ort_inference(ort_tok, text)
for n in range(len(actual_ids)):
expected_ids = tokenizer.encode(text[n], return_tensors="np")
try:
np.testing.assert_array_equal(
expected_ids[0], actual_ids[n][1:expected_ids.shape[1] + 1])
except AssertionError:
print("index is ", n)
raise
if __name__ == '__main__':