Add Bert tokenizer in the supported model list and code refinement (#503)
* Add Bert tokenizer in the supported model list and the related code refinement
* utest fix
Parent: 6209804ee9
Commit: 922b7cc387
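In practice, this change lets `gen_processing_models` convert a HuggingFace BERT-family tokenizer into an ONNX pre-processing graph. A minimal sketch mirroring the `test_bert_tokenizer` case added in this commit (assumes `transformers` and `onnxruntime` are installed):

```python
import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

# convert the HF tokenizer to an ONNX model and run it through onnxruntime
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='np')

ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids)
```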
@@ -4,24 +4,42 @@
 ###############################################################################
 
 """
-The entry point to onnxruntime-extensions package.
+The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and
+post-processing tasks. In addition, it also provides an API to register custom operations implemented in Python.
+This enables more flexibility and control over model execution, thus expanding the functionality of the ONNX Runtime.
 """
 
 __author__ = "Microsoft"
 
+__all__ = [
+    'gen_processing_models',
+    'get_library_path',
+    'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
+    'enable_py_op',
+    'expand_onnx_inputs',
+    'hook_model_op',
+    'default_opset_domain',
+    'OrtPyFunction', 'PyOrtFunction',
+    'optimize_model',
+    'make_onnx_model',
+    'ONNXRuntimeError',
+    'hash_64',
+    '__version__',
+]
+
 from ._version import __version__
-from ._ocos import get_library_path  # noqa
-from ._ocos import Opdef, PyCustomOpDef  # noqa
-from ._ocos import hash_64  # noqa
-from ._ocos import enable_py_op  # noqa
-from ._ocos import expand_onnx_inputs  # noqa
-from ._ocos import hook_model_op  # noqa
-from ._ocos import default_opset_domain  # noqa
+from ._ocos import get_library_path
+from ._ocos import Opdef, PyCustomOpDef
+from ._ocos import hash_64
+from ._ocos import enable_py_op
+from ._ocos import expand_onnx_inputs
+from ._ocos import hook_model_op
+from ._ocos import default_opset_domain
 from ._cuops import *  # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
 from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
 from .cvt import gen_processing_models
 
+# rename the implementation with a more formal name
 onnx_op = Opdef.declare
 PyOp = PyCustomOpDef
 
@@ -253,7 +253,9 @@ class BertTokenizer(CustomOp):
     def serialize_attr(cls, attrs):
         attrs_data = {}
         for k_, v_ in attrs.items():
-            if k_ == 'vocab_file':
+            if k_ == 'vocab':
+                attrs_data['vocab_file'] = v_
+            elif k_ == 'vocab_file':
                 with open(v_, "r", encoding='utf-8') as model_file:
                     lines = model_file.readlines()
                 attrs_data[k_] = '\n'.join(lines)
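With this change, `serialize_attr` accepts the vocabulary either inline under the `vocab` key (as produced by the new `HFTokenizerConverter.bert_tokenizer` further below) or as a file path under `vocab_file`; both end up as a newline-joined vocabulary string in the op attributes. A sketch of the two spellings, assuming the `BertTokenizer` custom-op class is importable from the package internals and using a placeholder `vocab.txt`:

```python
from onnxruntime_extensions._cuops import BertTokenizer  # internal module path, assumed

# both calls yield the same kind of payload: a newline-joined vocabulary string
attrs_inline = BertTokenizer.serialize_attr({'vocab': '[PAD]\n[UNK]\nhello'})
attrs_from_file = BertTokenizer.serialize_attr({'vocab_file': 'vocab.txt'})  # reads and joins the file lines
```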
@@ -2,6 +2,8 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 ###############################################################################
+from typing import Callable
+
 
 class PyCustomOpDef:
     undefined: int = ...

@@ -21,6 +23,8 @@ class PyCustomOpDef:
     dt_complex64: int = ...
     dt_complex128: int = ...
     dt_bfloat16: int = ...
+    def install_hooker(self, invocation_handler: Callable) -> None:
+        ...
     ...
 
 
@@ -9,9 +9,9 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 
 import json
 import onnx
-import numpy as np
+from numpy import array as nparray
 from functools import partial
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 
 from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
@@ -32,6 +32,25 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs.update(**kwargs)
         return attrs
 
+    def bert_tokenizer(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        # has to be sorted since the id of token was generated automatically.
+        ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        vocab = '\n'.join(ordered_vocab.keys())
+        attrs = dict(vocab=vocab)
+        init_kwargs = hf_bert_tokenizer.init_kwargs
+        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0
+        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0
+        attrs.update(**kwargs)
+        return attrs
+
+    def bert_decoder(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        attrs = {'vocab': json.dumps(
+            hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))}
+        attrs.update(**kwargs)
+        return attrs
+
     def bpe_decoder(self, **kwargs):
         decoder = self.tokenizer.decoder
         id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
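The `vocab` attribute built by `bert_tokenizer` is a newline-joined token list in which the line index must equal the token id, hence the sort by id before joining. A self-contained illustration with a toy vocabulary:

```python
from collections import OrderedDict

toy_vocab = {'hello': 2, '[PAD]': 0, '[UNK]': 1}  # insertion order differs from id order
ordered_vocab = OrderedDict(sorted(toy_vocab.items(), key=lambda item: int(item[1])))
print('\n'.join(ordered_vocab.keys()))
# [PAD]
# [UNK]
# hello   <- line i now holds the token whose id is i
```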
@@ -95,22 +114,28 @@ TokenOpParam = namedtuple("TokenOpParam",
                                        "default_inputs"],
                           defaults=(None, None, None, None, None))
 
-# fmt: off
+# @formatter:off
 _PROCESSOR_DICT = {
-    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                                     None, None),
-    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "DistilBertTokenizer":
+        TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                     'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
+                                     None, None, None),
+    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                 default_inputs={'add_eos': [True]}),
     "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
                                    "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                    default_inputs={'add_bos': [True]}),
 }
-# fmt: on
+# @formatter:on
 
 
 class HFTokenizerOnnxGraph:
 
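`_PROCESSOR_DICT` is keyed by the HuggingFace tokenizer class name, so both `BertTokenizer` and `DistilBertTokenizer` instances now resolve to the `BertTokenizer` pre-processing op and `BertDecoder` post-processing op. A short sketch of the lookup from the caller's side:

```python
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models

# the tokenizer's type name selects the "DistilBertTokenizer" entry above
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
pre_model = gen_processing_models(tokenizer, pre_kwargs={})[0]
```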
@@ -137,31 +162,34 @@ class HFTokenizerOnnxGraph:
         _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
         g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        default_inputs = []
         if with_default_inputs:
             op_class = SingleOpGraph.get_op_class(_cvt_op)
             default_inputs = op_class.input_default_values()
             if default_inputs is None:
-                raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
-            n_inputs = len(default_inputs)
-            if self.cvt_quadruple.default_inputs is not None:
-                default_inputs.update(self.cvt_quadruple.default_inputs)
-                if len(default_inputs) != n_inputs:
-                    raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
-
-            new_initializers = []
-            for k, v in default_inputs.items():
-                input_value_info = next((i for i in g.input if i.name == k), None)
-                if input_value_info is None:
-                    raise ValueError("The input {} is not found in the graph".format(k))
-
-                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
-                value = np.array(v, np_dtype)
-                new_initializers.append(onnx.numpy_helper.from_array(value, k))
-            g.initializer.extend(new_initializers)
-            new_inputs = [i for i in g.input if i.name not in default_inputs]
-            g.ClearField("input")
-            g.input.extend(new_inputs)
+                return g
+
+        # add default_inputs into initializers to simplify the model input
+        n_inputs = len(default_inputs)
+        if self.cvt_quadruple.default_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_inputs)
+            if len(default_inputs) != n_inputs:
+                raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
+
+        new_initializers = []
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError("The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
         return g
 
     def post_processing(self, **kwargs):
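After this rework, any defaults declared on the op or in `TokenOpParam.default_inputs` are baked into the graph as initializers and dropped from the model inputs, so the exported tokenizer is callable with the text input alone. This is grounded in the Llama test in this commit:

```python
from transformers import AutoTokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# WITH_DEFAULT_INPUTS folds defaults such as add_bos into initializers
ort_tok = OrtPyFunction.from_model(gen_processing_models(
    tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
ids = ort_tok(["I was born in 92000, and this is falsé."])[0]  # only the text is passed
```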
@@ -17,7 +17,7 @@ from ._extensions_pydll import (  # noqa
 def get_library_path():
     """
     The custom operator library binary path
-    :return: A string of the this library path.
+    :return: A string of this library path.
     """
     mod = sys.modules['onnxruntime_extensions._extensions_pydll']
     return mod.__file__
 
@@ -11,11 +11,11 @@ import numpy as np
 from ._ocos import default_opset_domain, get_library_path  # noqa
 from ._cuops import onnx, onnx_proto, SingleOpGraph
 
 
 _ort_check_passed = False
 try:
     from packaging import version as _ver
     import onnxruntime as _ort
 
     if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
         _ort_check_passed = True
 except ImportError:
@@ -37,6 +37,7 @@ def get_opset_version_from_ort():
         "1.12": 17,
         "1.13": 17,
         "1.14": 18,
+        "1.15": 18
     }
 
     ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -59,6 +60,13 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
 
 
 class OrtPyFunction:
+    """
+    OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
+    equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
+    standard Python function. The order of the function arguments correlates directly with
+    the sequence of the input/output in the ONNX graph.
+    """
+
     def get_ort_session_options(self):
         so = _ort.SessionOptions()
         for k, v in self.extra_session_options.items():
@@ -66,7 +74,7 @@ class OrtPyFunction:
         so.register_custom_ops_library(get_library_path())
         return so
 
-    def __init__(self, cpu_only=None):
+    def __init__(self, path_or_model=None, cpu_only=None):
         self._onnx_model = None
         self.ort_session = None
         self.default_inputs = {}
@@ -75,6 +83,14 @@ class OrtPyFunction:
         if _ort.get_device() == 'GPU':
             self.execution_providers = ['CUDAExecutionProvider']
         self.extra_session_options = {}
+        mpath = None
+        if isinstance(path_or_model, str):
+            oxml = onnx.load_model(path_or_model)
+            mpath = path_or_model
+        else:
+            oxml = path_or_model
+        if path_or_model is not None:
+            self._bind(oxml, mpath)
 
     def create_from_customop(self, op_type, *args, **kwargs):
         graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
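With `path_or_model` on the constructor, an `OrtPyFunction` can be created directly from a model path or an in-memory `ModelProto`, which is what the slimmed-down `from_model` below relies on. A sketch with a placeholder file name:

```python
from onnxruntime_extensions import OrtPyFunction

# 'tokenizer.onnx' is a hypothetical model produced by gen_processing_models
fn = OrtPyFunction('tokenizer.onnx')  # now equivalent to OrtPyFunction.from_model('tokenizer.onnx')
```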
@@ -130,17 +146,13 @@ class OrtPyFunction:
 
     @classmethod
     def from_customop(cls, op_type, *args, **kwargs):
-        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
+        return (cls(cpu_only=cls._get_kwarg_device(kwargs))
+                .create_from_customop(op_type, *args, **kwargs))
 
     @classmethod
     def from_model(cls, path_or_model, *args, **kwargs):
-        mpath = None
-        if isinstance(path_or_model, str):
-            oxml = onnx.load_model(path_or_model)
-            mpath = path_or_model
-        else:
-            oxml = path_or_model
-        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)
+        fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
+        return fn
 
     def _argument_map(self, *args, **kwargs):
         idx = 0
@@ -169,7 +181,7 @@ class OrtPyFunction:
 
 
 def optimize_model(model_or_file, output_file):
-    sess_options = OrtPyFunction.get_ort_session_options()
+    sess_options = OrtPyFunction().get_ort_session_options()
     sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
     sess_options.optimized_model_filepath = output_file
     _ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
@@ -241,6 +241,6 @@ class WhisperDataProcGraph:
         inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
         del g.input[:]
         g.input.extend(inputs)
-        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
+        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
 
         return make_onnx_model(g, opset_version=self.opset_version)
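The Whisper post-processing output is now one decoded string per batch entry (`['N', 'text']`) rather than per-token (`['N', 'seq_len', 'text']`). The relocated `test_whisper_overall` below exercises this; in outline, with `post_m` being the post-processing model from `gen_processing_models`:

```python
import numpy as np
from onnxruntime_extensions import OrtPyFunction

fn_post = OrtPyFunction.from_model(post_m)  # post_m: Whisper post-processing ModelProto (see the new test file below)
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
assert rel[0] == "$%&"  # a single string per batch entry
```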
@@ -50,7 +50,8 @@ class TestAudioCodec(unittest.TestCase):
     def test_decoder_resampling(self):
         test_file = util.get_test_data_file('data', 'jfk.flac')
         blob = bytearray(util.read_file(test_file, mode='rb'))
-        decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
+        decoder = PyOrtFunction.from_customop(
+            'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
         pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
         self.assertEqual(pcm_tensor.shape, (1, 176000))
 
@@ -8,7 +8,6 @@ from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
 import onnx
 from onnx import onnx_pb as onnx_proto
 
-
 _is_torch_available = False
 try:
     import torch
@@ -1,106 +1,41 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import sys
 import unittest
 
 import numpy as np
-import onnxruntime as _ort
-from packaging import version
-from transformers import AutoTokenizer, WhisperProcessor
-from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
+from transformers import AutoTokenizer
+from onnxruntime_extensions import OrtPyFunction, gen_processing_models
 
 
-@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
+    def test_bert_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors='np')
+        ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids)
+
     def test_llama_tokenizer(self):
         # replace the official model name after the model is not gated anymore
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-        ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
+        text = "I was born in 92000, and this is falsé."
+        ids = tokenizer.encode(text, return_tensors="np")
+
         ort_tok = OrtPyFunction.from_model(gen_processing_models(
             tokenizer,
             pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
-        actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)
 
     def test_t5_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
-        ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
+        text = "best hotel in bay area."
+        ids = tokenizer.encode(text, return_tensors="np")
         ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
-        actual_ids = ort_tok(["best hotel in bay area."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)
 
-    def test_whisper_overall(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, post_m = gen_processing_models(processor,
-                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
-                                              post_kwargs={})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
-        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
-        log_mel = fn_pre(simaudio)
-
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-        fn_post = OrtPyFunction.from_model(post_m)
-        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
-        self.assertEqual(rel[0], "$%&")
-
-    def test_whisper_audio_decoder(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
-        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
-        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
-
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_ort_stft_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
-        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
-        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_stft_norm_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-
 
 if __name__ == '__main__':
     unittest.main()
@@ -114,5 +114,6 @@ class TestBertTokenizer(unittest.TestCase):
 
         print("\n*** Offset mapping tests complete. ***\n")
 
+
 if __name__ == "__main__":
     unittest.main()
@@ -55,7 +55,8 @@ class TestBertTokenizerOp(unittest.TestCase):
             input="cat isnot playing toyssss"
         )
 
+
     def test_text_to_case1_with_hf_tok(self):
         ort_tok = pnp.PreHuggingFaceBert(hf_tok=_bert_cased_tokenizer)
         model = pnp.export(pnp.SequentialProcessingModule(ort_tok), ["whatever"], opset_version=12)
@@ -0,0 +1,88 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import sys
+import unittest
+
+import numpy as np
+import onnxruntime as _ort
+from packaging import version
+from transformers import WhisperProcessor
+from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
+
+
+@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
+class TestHuggingfaceWhisper(unittest.TestCase):
+    def test_whisper_overall(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, post_m = gen_processing_models(processor,
+                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
+                                              post_kwargs={})
+
+        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
+        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
+        log_mel = fn_pre(simaudio)
+
+        self.assertEqual(log_mel.shape, (1, 80, 3000))
+
+        fn_post = OrtPyFunction.from_model(post_m)
+        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
+        self.assertEqual(rel[0], "$%&")
+
+    def test_whisper_audio_decoder(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
+
+        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
+        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
+        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
+
+        self.assertEqual(log_mel.shape, (1, 80, 3000))
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_ort_stft_consistency(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
+        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
+        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_stft_norm_consistency(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+
+if __name__ == "__main__":
+    unittest.main()