Add Bert tokenizer in the supported model list and code refinement (#503)

* Add Bert tokenizer in the supported model list and the related code refinement
* Unit test fix

Parent: 6209804ee9
Commit: 922b7cc387
@@ -4,24 +4,42 @@
 ###############################################################################
+"""
+The entry point to the onnxruntime-extensions package.
+
+The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and
+post-processing tasks. In addition, it also provides an API to register custom operations implemented in Python.
+This enables more flexibility and control over model execution, thus expanding the functionality of the ONNX Runtime.
+"""
 
 __author__ = "Microsoft"
 
+__all__ = [
+    'gen_processing_models',
+    'get_library_path',
+    'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
+    'enable_py_op',
+    'expand_onnx_inputs',
+    'hook_model_op',
+    'default_opset_domain',
+    'OrtPyFunction', 'PyOrtFunction',
+    'optimize_model',
+    'make_onnx_model',
+    'ONNXRuntimeError',
+    'hash_64',
+    '__version__',
+]
+
 from ._version import __version__
-from ._ocos import get_library_path  # noqa
-from ._ocos import Opdef, PyCustomOpDef  # noqa
-from ._ocos import hash_64  # noqa
-from ._ocos import enable_py_op  # noqa
-from ._ocos import expand_onnx_inputs  # noqa
-from ._ocos import hook_model_op  # noqa
-from ._ocos import default_opset_domain  # noqa
+from ._ocos import get_library_path
+from ._ocos import Opdef, PyCustomOpDef
+from ._ocos import hash_64
+from ._ocos import enable_py_op
+from ._ocos import expand_onnx_inputs
+from ._ocos import hook_model_op
+from ._ocos import default_opset_domain
 from ._cuops import *  # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
 from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
 from .cvt import gen_processing_models
 
 # rename the implementation with a more formal name
 onnx_op = Opdef.declare
 PyOp = PyCustomOpDef
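For context, a minimal sketch of how this public surface is typically consumed, registering the extensions library with an onnxruntime session (the model filename is hypothetical):

    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    # make the custom operators shipped in this package visible to the session
    so.register_custom_ops_library(get_library_path())
    sess = ort.InferenceSession('model_with_custom_ops.onnx', so)  # hypothetical model file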
@@ -253,7 +253,9 @@ class BertTokenizer(CustomOp):
     def serialize_attr(cls, attrs):
         attrs_data = {}
         for k_, v_ in attrs.items():
-            if k_ == 'vocab_file':
+            if k_ == 'vocab':
+                attrs_data['vocab_file'] = v_
+            elif k_ == 'vocab_file':
                 with open(v_, "r", encoding='utf-8') as model_file:
                     lines = model_file.readlines()
                 attrs_data[k_] = '\n'.join(lines)
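The new 'vocab' branch lets the converter pass an in-memory vocabulary string straight through, while 'vocab_file' still reads the vocabulary from disk; both forms end up in the same 'vocab_file' attribute. A rough sketch of the two call shapes, assuming serialize_attr is exposed as a classmethod and with illustrative vocabulary content and file name:

    # internal module; the class lives alongside the other custom-op definitions
    from onnxruntime_extensions._cuops import BertTokenizer

    # vocabulary handed over directly as a newline-joined string
    attrs = BertTokenizer.serialize_attr({'vocab': '[PAD]\n[UNK]\n[CLS]'})
    # vocabulary loaded from a local file path (illustrative name)
    attrs = BertTokenizer.serialize_attr({'vocab_file': 'vocab.txt'})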
@@ -2,6 +2,8 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 ###############################################################################
+from typing import Callable
+
 
 class PyCustomOpDef:
     undefined: int = ...

@@ -21,6 +23,8 @@ class PyCustomOpDef:
     dt_complex64: int = ...
     dt_complex128: int = ...
     dt_bfloat16: int = ...
+    def install_hooker(self, invocation_handler: Callable) -> None:
+        ...
     ...
@@ -9,9 +9,9 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 
 import json
 import onnx
-import numpy as np
+from numpy import array as nparray
 from functools import partial
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 
 from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
@@ -32,6 +32,25 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs.update(**kwargs)
         return attrs
 
+    def bert_tokenizer(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        # has to be sorted, since the token ids were generated automatically
+        ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        vocab = '\n'.join(ordered_vocab.keys())
+        attrs = dict(vocab=vocab)
+        init_kwargs = hf_bert_tokenizer.init_kwargs
+        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0
+        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0
+        attrs.update(**kwargs)
+        return attrs
+
+    def bert_decoder(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        attrs = {'vocab': json.dumps(
+            hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))}
+        attrs.update(**kwargs)
+        return attrs
+
     def bpe_decoder(self, **kwargs):
         decoder = self.tokenizer.decoder
         id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
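The sorting in bert_tokenizer matters because the custom op receives the vocabulary as newline-joined lines, where the line index becomes the token id. A small self-contained illustration of that invariant:

    from collections import OrderedDict

    # toy HuggingFace-style vocab mapping token -> id, in arbitrary dict order
    vocab = {'[CLS]': 2, '[PAD]': 0, '[UNK]': 1}
    ordered = OrderedDict(sorted(vocab.items(), key=lambda item: int(item[1])))
    print('\n'.join(ordered.keys()))  # [PAD], [UNK], [CLS]: line number == token id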
@@ -95,22 +114,28 @@ TokenOpParam = namedtuple("TokenOpParam",
                                        "default_inputs"],
                           defaults=(None, None, None, None, None))
 
+# fmt: off
 # @formatter:off
 _PROCESSOR_DICT = {
-    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                                     None, None),
-    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "DistilBertTokenizer":
+        TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                     'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
+                                     None, None, None),
+    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                 default_inputs={'add_eos': [True]}),
-    "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                   "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+    "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                   "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                    default_inputs={'add_bos': [True]}),
 }
+# fmt: on
 # @formatter:on
 
 
 class HFTokenizerOnnxGraph:
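With BertTokenizer and DistilBertTokenizer registered in _PROCESSOR_DICT, a HuggingFace BERT tokenizer can now be converted directly. A hedged usage sketch mirroring the new unit test (the checkpoint name is only an example):

    from transformers import AutoTokenizer
    from onnxruntime_extensions import gen_processing_models, OrtPyFunction

    hf_tok = AutoTokenizer.from_pretrained('bert-base-uncased')  # example checkpoint
    pre_model = gen_processing_models(hf_tok, pre_kwargs={})[0]
    ort_tok = OrtPyFunction(pre_model)
    input_ids = ort_tok(['hello world'])[0]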
@@ -137,31 +162,34 @@ class HFTokenizerOnnxGraph:
         _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
         g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        default_inputs = []
         if with_default_inputs:
             op_class = SingleOpGraph.get_op_class(_cvt_op)
             default_inputs = op_class.input_default_values()
             if default_inputs is None:
                 raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
-            n_inputs = len(default_inputs)
-            if self.cvt_quadruple.default_inputs is not None:
-                default_inputs.update(self.cvt_quadruple.default_inputs)
-                if len(default_inputs) != n_inputs:
-                    raise ValueError("Op: {} doesn't have the inputs from its TokenOpParam.".format(_cvt_op))
-
-            for k, v in default_inputs.items():
-                input_value_info = next((i for i in g.input if i.name == k), None)
-                if input_value_info is None:
-                    raise ValueError("The input {} is not found in the graph".format(k))
-                new_initializers = []
-
-                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
-                value = np.array(v, np_dtype)
-                new_initializers.append(onnx.numpy_helper.from_array(value, k))
-            g.initializer.extend(new_initializers)
-            new_inputs = [i for i in g.input if i.name not in default_inputs]
-            g.ClearField("input")
-            g.input.extend(new_inputs)
+
+        new_initializers = []
+        # add default_inputs into initializers to simplify the model input
+        n_inputs = len(default_inputs)
+        if self.cvt_quadruple.default_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_inputs)
+            if len(default_inputs) != n_inputs:
+                raise ValueError("Op: {} doesn't have the inputs from its TokenOpParam.".format(_cvt_op))
+
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError("The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
         return g
 
     def post_processing(self, **kwargs):
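The refactored block above folds known default inputs into graph initializers so callers only feed the real inputs. A standalone sketch of that folding step under the same assumptions (the helper name is ours, not the library's):

    import numpy as np
    import onnx

    def fold_default_input(graph, name, value):
        # freeze one graph input into a constant initializer
        vi = next((i for i in graph.input if i.name == name), None)
        if vi is None:
            raise ValueError("The input {} is not found in the graph".format(name))
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(vi.type.tensor_type.elem_type)
        graph.initializer.append(onnx.numpy_helper.from_array(np.array(value, np_dtype), name))
        # drop the now-constant entry from the graph inputs
        remaining = [i for i in graph.input if i.name != name]
        del graph.input[:]
        graph.input.extend(remaining)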
@@ -17,7 +17,7 @@ from ._extensions_pydll import (  # noqa
 def get_library_path():
     """
     The custom operator library binary path
-    :return: A string of the this library path.
+    :return: A string of this library path.
     """
     mod = sys.modules['onnxruntime_extensions._extensions_pydll']
     return mod.__file__
@@ -11,11 +11,11 @@ import numpy as np
 from ._ocos import default_opset_domain, get_library_path  # noqa
 from ._cuops import onnx, onnx_proto, SingleOpGraph
 
 
 _ort_check_passed = False
 try:
     from packaging import version as _ver
     import onnxruntime as _ort
 
     if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
         _ort_check_passed = True
 except ImportError:
@@ -37,6 +37,7 @@ def get_opset_version_from_ort():
         "1.12": 17,
         "1.13": 17,
         "1.14": 18,
+        "1.15": 18
     }
 
     ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -59,6 +60,13 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
 
 
 class OrtPyFunction:
+    """
+    OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
+    equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
+    standard Python function. The order of the function arguments correlates directly with
+    the sequence of the input/output in the ONNX graph.
+    """
 
     def get_ort_session_options(self):
         so = _ort.SessionOptions()
         for k, v in self.extra_session_options.items():
@@ -66,7 +74,7 @@ class OrtPyFunction:
         so.register_custom_ops_library(get_library_path())
         return so
 
-    def __init__(self, cpu_only=None):
+    def __init__(self, path_or_model=None, cpu_only=None):
         self._onnx_model = None
         self.ort_session = None
         self.default_inputs = {}
@@ -75,6 +83,14 @@ class OrtPyFunction:
         if _ort.get_device() == 'GPU':
             self.execution_providers = ['CUDAExecutionProvider']
         self.extra_session_options = {}
+        mpath = None
+        if isinstance(path_or_model, str):
+            oxml = onnx.load_model(path_or_model)
+            mpath = path_or_model
+        else:
+            oxml = path_or_model
+        if path_or_model is not None:
+            self._bind(oxml, mpath)
 
     def create_from_customop(self, op_type, *args, **kwargs):
         graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
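Binding in the constructor means the two spellings below should now be equivalent; a short sketch (the model path is hypothetical):

    from onnxruntime_extensions import OrtPyFunction

    fn = OrtPyFunction('tokenizer.onnx')             # new: bind at construction
    fn = OrtPyFunction.from_model('tokenizer.onnx')  # existing factory, same effect
    # the wrapper is then callable like a plain function, e.g. fn(['some text'])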
@@ -130,17 +146,13 @@ class OrtPyFunction:
 
     @classmethod
     def from_customop(cls, op_type, *args, **kwargs):
-        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
+        return (cls(cpu_only=cls._get_kwarg_device(kwargs))
+                .create_from_customop(op_type, *args, **kwargs))
 
     @classmethod
     def from_model(cls, path_or_model, *args, **kwargs):
-        mpath = None
-        if isinstance(path_or_model, str):
-            oxml = onnx.load_model(path_or_model)
-            mpath = path_or_model
-        else:
-            oxml = path_or_model
-        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)
+        fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
+        return fn
 
     def _argument_map(self, *args, **kwargs):
         idx = 0
@@ -169,7 +181,7 @@ class OrtPyFunction:
 
 
 def optimize_model(model_or_file, output_file):
-    sess_options = OrtPyFunction.get_ort_session_options()
+    sess_options = OrtPyFunction().get_ort_session_options()
    sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
     sess_options.optimized_model_filepath = output_file
     _ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
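The one-line fix is needed because get_ort_session_options reads self.extra_session_options, so it must be called on an instance rather than on the class. Typical use stays the same (paths are hypothetical):

    from onnxruntime_extensions import optimize_model

    # writes a basic-level optimized copy of the model
    optimize_model('model.onnx', 'model.opt.onnx')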
@@ -241,6 +241,6 @@ class WhisperDataProcGraph:
         inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
         del g.input[:]
         g.input.extend(inputs)
-        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
+        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
 
         return make_onnx_model(g, opset_version=self.opset_version)
@@ -50,7 +50,8 @@ class TestAudioCodec(unittest.TestCase):
     def test_decoder_resampling(self):
         test_file = util.get_test_data_file('data', 'jfk.flac')
         blob = bytearray(util.read_file(test_file, mode='rb'))
-        decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
+        decoder = PyOrtFunction.from_customop(
+            'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
         pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
         self.assertEqual(pcm_tensor.shape, (1, 176000))
@@ -8,7 +8,6 @@ from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
 import onnx
 from onnx import onnx_pb as onnx_proto
 
-
 _is_torch_available = False
 try:
     import torch
@@ -1,106 +1,41 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import sys
 import unittest
 
 import numpy as np
 import onnxruntime as _ort
 from packaging import version
-from transformers import AutoTokenizer, WhisperProcessor
-from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
+from transformers import AutoTokenizer
+from onnxruntime_extensions import OrtPyFunction, gen_processing_models
 
 
 @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
+    def test_bert_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors='np')
+        ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids)
+
     def test_llama_tokenizer(self):
         # replace with the official model name once the model is no longer gated
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-        ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
+        text = "I was born in 92000, and this is falsé."
+        ids = tokenizer.encode(text, return_tensors="np")
 
         ort_tok = OrtPyFunction.from_model(gen_processing_models(
             tokenizer,
             pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
-        actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)
 
     def test_t5_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
-        ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
+        text = "best hotel in bay area."
+        ids = tokenizer.encode(text, return_tensors="np")
         ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
-        actual_ids = ort_tok(["best hotel in bay area."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)
 
-    def test_whisper_overall(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, post_m = gen_processing_models(processor,
-                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
-                                              post_kwargs={})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
-        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
-        log_mel = fn_pre(simaudio)
-
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-        fn_post = OrtPyFunction.from_model(post_m)
-        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
-        self.assertEqual(rel[0], "$%&")
-
-    def test_whisper_audio_decoder(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
-        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
-        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
-
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_ort_stft_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
-        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
-        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_stft_norm_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-
 
 if __name__ == '__main__':
     unittest.main()
@@ -114,5 +114,6 @@ class TestBertTokenizer(unittest.TestCase):
 
     print("\n*** Offset mapping tests complete. ***\n")
 
+
 if __name__ == "__main__":
     unittest.main()
@@ -55,7 +55,6 @@ class TestBertTokenizerOp(unittest.TestCase):
             input="cat isnot playing toyssss"
         )
 
-
     def test_text_to_case1_with_hf_tok(self):
         ort_tok = pnp.PreHuggingFaceBert(hf_tok=_bert_cased_tokenizer)
         model = pnp.export(pnp.SequentialProcessingModule(ort_tok), ["whatever"], opset_version=12)
@@ -0,0 +1,88 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import sys
+import unittest
+
+import numpy as np
+import onnxruntime as _ort
+from packaging import version
+from transformers import WhisperProcessor
+from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
+
+
+@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
+class TestHuggingfaceWhisper(unittest.TestCase):
+    def test_whisper_overall(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, post_m = gen_processing_models(processor,
+                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
+                                              post_kwargs={})
+
+        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
+        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
+        log_mel = fn_pre(simaudio)
+
+        self.assertEqual(log_mel.shape, (1, 80, 3000))
+
+        fn_post = OrtPyFunction.from_model(post_m)
+        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
+        self.assertEqual(rel[0], "$%&")
+
+    def test_whisper_audio_decoder(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
+
+        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
+        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
+        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
+
+        self.assertEqual(log_mel.shape, (1, 80, 3000))
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_ort_stft_consistency(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
+        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
+        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_stft_norm_consistency(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+
+if __name__ == "__main__":
+    unittest.main()