add more hf models into converter APIs (#562)

This commit is contained in:
Wenbing Li 2023-09-18 14:38:32 -07:00 committed by GitHub
Parent 914509d524
Commit e899da29d2
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 65603 additions and 8 deletions

View File

@@ -13,6 +13,7 @@ __author__ = "Microsoft"
 __all__ = [
     'gen_processing_models',
+    'ort_inference',
     'get_library_path',
     'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
     'enable_py_op',
@@ -37,7 +38,7 @@ from ._ocos import hook_model_op
 from ._ocos import default_opset_domain
 from ._cuops import *  # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
-from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
+from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model, ONNXRuntimeError
 from .cvt import gen_processing_models
 # rename the implementation with a more formal name

View File

@@ -152,7 +152,10 @@ _PROCESSOR_DICT = {
                                       default_inputs={'add_eos': [True]}),
    "LlamaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
                                   'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
-                                  default_inputs={'add_bos': [True]})
+                                  default_inputs={'add_bos': [True]}),
+   "XLMRobertaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
+                                       'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
+                                       default_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]}),
 }
 # @formatter:on

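Note: with this table entry in place, an XLM-RoBERTa tokenizer converts the same way as the other SentencePiece-based tokenizers. A minimal usage sketch mirroring the new test_xmlroberta_tokenizer test added below (assumes transformers is installed; the xlm-roberta-base checkpoint is the one the test uses):

from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, ort_inference

# The slow (SentencePiece-backed) tokenizer is required for conversion.
hf_tok = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)

# add_bos/add_eos/fairseq are wired in through the new default_inputs above.
onnx_tok, _ = gen_processing_models(hf_tok, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
ids, *_ = ort_inference(onnx_tok, ["Hello, world!"])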
View File

@@ -170,7 +170,7 @@ class OrtPyFunction:
                 np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
             idx += 1
-        # feed.update(kwargs)
+        feed.update(kwargs)
         return feed

     def __call__(self, *args, **kwargs):
@@ -180,6 +180,13 @@ class OrtPyFunction:
         return outputs[0] if len(outputs) == 1 else tuple(outputs)


+def ort_inference(model, *args, cpu_only=True, **kwargs):
+    """
+    Run an ONNX model with ORT where args are inputs and return values are outputs.
+    """
+    return OrtPyFunction(model, cpu_only=cpu_only)(*args, **kwargs)
+
+
 def optimize_model(model_or_file, output_file):
     sess_options = OrtPyFunction().get_ort_session_options()
     sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

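Note: ort_inference is a thin convenience wrapper: it builds an OrtPyFunction from the model and immediately calls it, so positional arguments become model inputs and outputs come back as numpy arrays, running on CPU by default (cpu_only=True). A minimal sketch (the GPT-2 tokenizer is an assumption here, chosen because the converter's processor table already covers it):

from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, ort_inference

hf_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
tok_model, _ = gen_processing_models(hf_tok, pre_kwargs={})

# One call covers both session creation and inference.
ids, *_ = ort_inference(tok_model, ["hello world"])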
View File

@@ -10,7 +10,7 @@ cvt.py: Processing Graph Converter and Generator
 from typing import Union

 from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph  # noqa
-from ._ortapi2 import make_onnx_model
+from ._ortapi2 import make_onnx_model, SingleOpGraph

 _is_torch_available = False
@@ -22,6 +22,9 @@ except ImportError:
     WhisperDataProcGraph = None

+_PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}
+
+
 def gen_processing_models(processor: Union[str, object],
                           pre_kwargs: dict = None,
                           post_kwargs: dict = None,
@@ -52,8 +55,21 @@ def gen_processing_models(processor: Union[str, object],
     """
     if pre_kwargs is None and post_kwargs is None:
         raise ValueError("Either pre_kwargs or post_kwargs should be provided. None means no processing")
+    if isinstance(processor, str):
+        g_pre, g_post = (None, None)
+        if pre_kwargs:
+            g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
+        if post_kwargs:
+            if pre_kwargs is None:
+                cls_name = processor
+            else:
+                if processor not in _PRE_POST_PAIR:
+                    raise RuntimeError(f"Cannot locate the post processing operator name from {processor}")
+                cls_name = _PRE_POST_PAIR[processor]
+            g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
+        return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None
-    cls_name = processor if isinstance(processor, str) else type(processor).__name__
+    cls_name = type(processor).__name__
     if cls_name == "WhisperProcessor":
         if WhisperDataProcGraph is None:
             raise ValueError("The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")

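Note: the new branch lets a plain operator name stand in for a Hugging Face processor object: pre_kwargs builds the named operator directly, and for post_kwargs the detokenizer name is looked up in _PRE_POST_PAIR (only the TrieTokenizer/TrieDetokenizer pair is registered so far). A sketch following the new test below (the vocabulary file path is illustrative):

from onnxruntime_extensions import gen_processing_models, ort_inference, util

# Illustrative path to an RWKV "world" vocabulary file.
vocab_data = util.read_file("rwkv_vocab_v20230424.txt", 'rb')

# "TrieTokenizer" pairs with "TrieDetokenizer" via _PRE_POST_PAIR.
tok, detok = gen_processing_models("TrieTokenizer",
                                   pre_kwargs={'vocab': vocab_data},
                                   post_kwargs={'vocab': vocab_data})

ids = ort_inference(tok, ["that dog is so cute"])
text = ort_inference(detok, ids)  # round-trips back to the input text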
View File

@@ -6,7 +6,7 @@
 void image_reader(const ortc::Tensor<std::string>& input,
                   ortc::Tensor<uint8_t>& output) {
   auto& input_data_dimensions = input.Shape();
-  int n = input_data_dimensions[0];
+  auto n = input_data_dimensions[0];
   if (n != 1) {
     ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
   }

File diff suppressed because it is too large.

View File

@@ -4,7 +4,7 @@ import unittest
 import numpy as np
 from transformers import AutoTokenizer
-from onnxruntime_extensions import OrtPyFunction, gen_processing_models
+from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference, util


 class TestAutoTokenizer(unittest.TestCase):
@@ -69,6 +69,45 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids, actual_ids)

+    def test_xmlroberta_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
+        # TODO: if there is <unk> in text, the result is not matched.
+        text = (
+            'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+            " add words that should not exsist and be tokenized to , such as saoneuhaoesuth")
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
+        actual_ids, *_ = ort_inference(ort_tok, [text])
+        np.testing.assert_array_equal(ids[0], actual_ids)
+
+    def test_trie_tokenizer(self):
+        vocab_file = util.get_test_data_file("data", "rwkv_vocab_v20230424.txt")
+        vocab_data = util.read_file(vocab_file, 'rb')
+        tok, detok = gen_processing_models("TrieTokenizer",
+                                           pre_kwargs={'vocab': vocab_data},
+                                           post_kwargs={'vocab': vocab_data})
+        text = ["that dog is so cute"]
+        ids = ort_inference(tok, text)
+        det_text = ort_inference(detok, ids)
+        self.assertEqual(text, det_text)
+
+    def test_microsoft_ph1(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto", use_fast=False)
+        code = '''```python
+def print_prime(n):
+   """
+   Print all primes between 1 and n
+   """'''
+        ids = tokenizer(code, return_tensors="np", return_attention_mask=False)
+
+        ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
+        actual_ids, *_ = ort_inference(ort_tok, [code])
+        self.assertEqual(len(ids['input_ids'].shape), len(actual_ids.shape))
+        # TODO: not matched.
+        # np.testing.assert_array_equal(ids['input_ids'], actual_ids)
+

 if __name__ == '__main__':
     unittest.main()

View File

@@ -1 +1 @@
-0.9.0
+0.10.0