add more hf models into converter APIs (#562)
This commit is contained in:
Родитель
914509d524
Коммит
e899da29d2
|
@ -13,6 +13,7 @@ __author__ = "Microsoft"
|
|||
|
||||
__all__ = [
|
||||
'gen_processing_models',
|
||||
'ort_inference',
|
||||
'get_library_path',
|
||||
'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
|
||||
'enable_py_op',
|
||||
|
@ -37,7 +38,7 @@ from ._ocos import hook_model_op
|
|||
from ._ocos import default_opset_domain
|
||||
from ._cuops import * # noqa
|
||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
|
||||
from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
|
||||
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model, ONNXRuntimeError
|
||||
from .cvt import gen_processing_models
|
||||
|
||||
# rename the implementation with a more formal name
|
||||
|
|
|
@ -152,7 +152,10 @@ _PROCESSOR_DICT = {
|
|||
default_inputs={'add_eos': [True]}),
|
||||
"LlamaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
|
||||
'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
|
||||
default_inputs={'add_bos': [True]})
|
||||
default_inputs={'add_bos': [True]}),
|
||||
"XLMRobertaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
|
||||
'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
|
||||
default_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]}),
|
||||
}
|
||||
# @formatter:on
|
||||
|
||||
|
|
|
@ -170,7 +170,7 @@ class OrtPyFunction:
|
|||
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||
idx += 1
|
||||
|
||||
# feed.update(kwargs)
|
||||
feed.update(kwargs)
|
||||
return feed
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
@ -180,6 +180,13 @@ class OrtPyFunction:
|
|||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||
|
||||
|
||||
def ort_inference(model, *args, cpu_only=True, **kwargs):
|
||||
"""
|
||||
Run an ONNX model with ORT where args are inputs and return values are outputs.
|
||||
"""
|
||||
return OrtPyFunction(model, cpu_only=cpu_only)(*args, **kwargs)
|
||||
|
||||
|
||||
def optimize_model(model_or_file, output_file):
|
||||
sess_options = OrtPyFunction().get_ort_session_options()
|
||||
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
|
|
|
@ -10,7 +10,7 @@ cvt.py: Processing Graph Converter and Generator
|
|||
from typing import Union
|
||||
|
||||
from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph # noqa
|
||||
from ._ortapi2 import make_onnx_model
|
||||
from ._ortapi2 import make_onnx_model, SingleOpGraph
|
||||
|
||||
|
||||
_is_torch_available = False
|
||||
|
@ -22,6 +22,9 @@ except ImportError:
|
|||
WhisperDataProcGraph = None
|
||||
|
||||
|
||||
_PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}
|
||||
|
||||
|
||||
def gen_processing_models(processor: Union[str, object],
|
||||
pre_kwargs: dict = None,
|
||||
post_kwargs: dict = None,
|
||||
|
@ -52,8 +55,21 @@ def gen_processing_models(processor: Union[str, object],
|
|||
"""
|
||||
if pre_kwargs is None and post_kwargs is None:
|
||||
raise ValueError("Either pre_kwargs or post_kwargs should be provided. None means no processing")
|
||||
if isinstance(processor, str):
|
||||
g_pre, g_post = (None, None)
|
||||
if pre_kwargs:
|
||||
g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
|
||||
if post_kwargs:
|
||||
if pre_kwargs is None:
|
||||
cls_name = processor
|
||||
else:
|
||||
if processor not in _PRE_POST_PAIR:
|
||||
raise RuntimeError(f"Cannot locate the post processing operator name from {processor}")
|
||||
cls_name = _PRE_POST_PAIR[processor]
|
||||
g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
|
||||
return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None
|
||||
|
||||
cls_name = processor if isinstance(processor, str) else type(processor).__name__
|
||||
cls_name = type(processor).__name__
|
||||
if cls_name == "WhisperProcessor":
|
||||
if WhisperDataProcGraph is None:
|
||||
raise ValueError("The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
void image_reader(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
auto& input_data_dimensions = input.Shape();
|
||||
int n = input_data_dimensions[0];
|
||||
auto n = input_data_dimensions[0];
|
||||
if (n != 1) {
|
||||
ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -4,7 +4,7 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from onnxruntime_extensions import OrtPyFunction, gen_processing_models
|
||||
from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference, util
|
||||
|
||||
|
||||
class TestAutoTokenizer(unittest.TestCase):
|
||||
|
@ -69,6 +69,45 @@ class TestAutoTokenizer(unittest.TestCase):
|
|||
actual_ids = ort_tok([text])[0]
|
||||
np.testing.assert_array_equal(ids, actual_ids)
|
||||
|
||||
def test_xmlroberta_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
|
||||
# TODO: if there is <unk> in text, the result is not matched.
|
||||
text = (
|
||||
'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
|
||||
" add words that should not exsist and be tokenized to , such as saoneuhaoesuth")
|
||||
ids = tokenizer.encode(text, return_tensors="np")
|
||||
|
||||
ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})
|
||||
actual_ids, *_ = ort_inference(ort_tok, [text])
|
||||
np.testing.assert_array_equal(ids[0], actual_ids)
|
||||
|
||||
def test_trie_tokenizer(self):
|
||||
vocab_file = util.get_test_data_file("data", "rwkv_vocab_v20230424.txt")
|
||||
vocab_data = util.read_file(vocab_file, 'rb')
|
||||
tok, detok = gen_processing_models("TrieTokenizer",
|
||||
pre_kwargs={'vocab': vocab_data},
|
||||
post_kwargs={'vocab': vocab_data})
|
||||
text = ["that dog is so cute"]
|
||||
ids = ort_inference(tok, text)
|
||||
det_text = ort_inference(detok, ids)
|
||||
self.assertEqual(text, det_text)
|
||||
|
||||
def test_microsoft_ph1(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto", use_fast=False)
|
||||
code = '''```python
|
||||
def print_prime(n):
|
||||
"""
|
||||
Print all primes between 1 and n
|
||||
"""'''
|
||||
|
||||
ids = tokenizer(code, return_tensors="np", return_attention_mask=False)
|
||||
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
|
||||
actual_ids, *_ = ort_inference(ort_tok, [code])
|
||||
self.assertEqual(len(ids['input_ids'].shape), len(actual_ids.shape))
|
||||
# TODO: not matched.
|
||||
# np.testing.assert_array_equal(ids['input_ids'], actual_ids)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1 +1 @@
|
|||
0.9.0
|
||||
0.10.0
|
Загрузка…
Ссылка в новой задаче