Add Bert tokenizer in the supported model list and code refinement (#503)

* Add Bert tokenizer in the supported model list and the related code refinement

* Unit test fix
Wenbing Li 2023-08-02 14:01:36 -07:00 committed by GitHub
Parent 6209804ee9
Commit 922b7cc387
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 226 additions and 139 deletions

View file

@@ -4,24 +4,42 @@
 ###############################################################################
 """
-The entry point to onnxruntime-extensions package.
+The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and
+post-processing tasks. In addition, it also provides an API to register custom operations implemented in Python.
+This enables more flexibility and control over model execution, thus expanding the functionality of the ONNX Runtime.
 """
 __author__ = "Microsoft"

+__all__ = [
+    'gen_processing_models',
+    'get_library_path',
+    'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
+    'enable_py_op',
+    'expand_onnx_inputs',
+    'hook_model_op',
+    'default_opset_domain',
+    'OrtPyFunction', 'PyOrtFunction',
+    'optimize_model',
+    'make_onnx_model',
+    'ONNXRuntimeError',
+    'hash_64',
+    '__version__',
+]
+
 from ._version import __version__
-from ._ocos import get_library_path  # noqa
-from ._ocos import Opdef, PyCustomOpDef  # noqa
-from ._ocos import hash_64  # noqa
-from ._ocos import enable_py_op  # noqa
-from ._ocos import expand_onnx_inputs  # noqa
-from ._ocos import hook_model_op  # noqa
-from ._ocos import default_opset_domain  # noqa
+from ._ocos import get_library_path
+from ._ocos import Opdef, PyCustomOpDef
+from ._ocos import hash_64
+from ._ocos import enable_py_op
+from ._ocos import expand_onnx_inputs
+from ._ocos import hook_model_op
+from ._ocos import default_opset_domain
 from ._cuops import *  # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
 from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
 from .cvt import gen_processing_models

+# rename the implementation with a more formal name
 onnx_op = Opdef.declare
 PyOp = PyCustomOpDef
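For reference, the names exported via `__all__` cover the typical usage pattern: locate the custom-op library and register it with an onnxruntime session. A short sketch, assuming a model that uses these custom ops is available at the hypothetical path "model_with_custom_ops.onnx":

import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
# register the onnxruntime-extensions custom-op library with the session
so.register_custom_ops_library(get_library_path())

# "model_with_custom_ops.onnx" is a placeholder, not a file from this repository
sess = ort.InferenceSession("model_with_custom_ops.onnx", so, providers=["CPUExecutionProvider"])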

View file

@@ -253,7 +253,9 @@ class BertTokenizer(CustomOp):
     def serialize_attr(cls, attrs):
         attrs_data = {}
         for k_, v_ in attrs.items():
-            if k_ == 'vocab_file':
+            if k_ == 'vocab':
+                attrs_data['vocab_file'] = v_
+            elif k_ == 'vocab_file':
                 with open(v_, "r", encoding='utf-8') as model_file:
                     lines = model_file.readlines()
                     attrs_data[k_] = '\n'.join(lines)
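With this change, `BertTokenizer.serialize_attr` accepts the vocabulary either as an in-memory `vocab` string or as a `vocab_file` path read from disk. A standalone sketch of the branch above, using a made-up `attrs` dictionary (the values are illustrative, not taken from the repository):

# sketch: mirror of the serialize_attr branch, with a hypothetical attrs dict
attrs = {"vocab": "[PAD]\n[UNK]\nhello\nworld", "do_lower_case": 1}

attrs_data = {}
for k_, v_ in attrs.items():
    if k_ == "vocab":
        # an already-serialized vocabulary string is forwarded directly
        attrs_data["vocab_file"] = v_
    elif k_ == "vocab_file":
        # a file path is read and its lines joined into the same newline-separated form
        with open(v_, "r", encoding="utf-8") as model_file:
            attrs_data[k_] = "\n".join(model_file.readlines())
    else:
        attrs_data[k_] = v_

print(attrs_data["vocab_file"].splitlines()[:2])  # ['[PAD]', '[UNK]']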

View file

@@ -2,6 +2,8 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 ###############################################################################
+from typing import Callable
+

 class PyCustomOpDef:
     undefined: int = ...
@@ -21,6 +23,8 @@ class PyCustomOpDef:
     dt_complex64: int = ...
     dt_complex128: int = ...
     dt_bfloat16: int = ...
+    def install_hooker(self, invocation_handler: Callable) -> None:
+        ...
     ...

View file

@@ -9,9 +9,9 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 import json
 import onnx
-import numpy as np
+from numpy import array as nparray
 from functools import partial
-from collections import namedtuple
+from collections import namedtuple, OrderedDict

 from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
@@ -32,6 +32,25 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs.update(**kwargs)
         return attrs

+    def bert_tokenizer(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        # has to be sorted since the id of token was generated automatically.
+        ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        vocab = '\n'.join(ordered_vocab.keys())
+        attrs = dict(vocab=vocab)
+        init_kwargs = hf_bert_tokenizer.init_kwargs
+        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0
+        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0
+        attrs.update(**kwargs)
+        return attrs
+
+    def bert_decoder(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        attrs = {'vocab': json.dumps(
+            hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))}
+        attrs.update(**kwargs)
+        return attrs
+
     def bpe_decoder(self, **kwargs):
         decoder = self.tokenizer.decoder
         id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
@@ -95,22 +114,28 @@ TokenOpParam = namedtuple("TokenOpParam",
                                          "default_inputs"],
                           defaults=(None, None, None, None, None))
-# fmt: off
+# @formatter:off
 _PROCESSOR_DICT = {
-    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                                     None, None),
-    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "DistilBertTokenizer":
+        TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                     'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
+                                     None, None, None),
+    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                 default_inputs={'add_eos': [True]}),
     "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
                                    "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
                                    default_inputs={'add_bos': [True]}),
 }
-# fmt: on
+# @formatter:on
@@ -137,31 +162,34 @@ class HFTokenizerOnnxGraph:
         _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
         g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
-        default_inputs = []
         if with_default_inputs:
             op_class = SingleOpGraph.get_op_class(_cvt_op)
             default_inputs = op_class.input_default_values()
             if default_inputs is None:
-                raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
-            n_inputs = len(default_inputs)
-            if self.cvt_quadruple.default_inputs is not None:
-                default_inputs.update(self.cvt_quadruple.default_inputs)
-                if len(default_inputs) != n_inputs:
-                    raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
-            new_initializers = []
-            for k, v in default_inputs.items():
-                input_value_info = next((i for i in g.input if i.name == k), None)
-                if input_value_info is None:
-                    raise ValueError("The input {} is not found in the graph".format(k))
-                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
-                value = np.array(v, np_dtype)
-                new_initializers.append(onnx.numpy_helper.from_array(value, k))
-            g.initializer.extend(new_initializers)
-            new_inputs = [i for i in g.input if i.name not in default_inputs]
-            g.ClearField("input")
-            g.input.extend(new_inputs)
+                return g
+
+            # add default_inputs into initializers to simplify the model input
+            n_inputs = len(default_inputs)
+            if self.cvt_quadruple.default_inputs is not None:
+                default_inputs.update(self.cvt_quadruple.default_inputs)
+                if len(default_inputs) != n_inputs:
+                    raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
+
+            new_initializers = []
+            for k, v in default_inputs.items():
+                input_value_info = next((i for i in g.input if i.name == k), None)
+                if input_value_info is None:
+                    raise ValueError("The input {} is not found in the graph".format(k))
+
+                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+                value = nparray(v, np_dtype)
+                new_initializers.append(onnx.numpy_helper.from_array(value, k))
+            g.initializer.extend(new_initializers)
+            new_inputs = [i for i in g.input if i.name not in default_inputs]
+            g.ClearField("input")
+            g.input.extend(new_inputs)
         return g

     def post_processing(self, **kwargs):
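The core of the new `bert_tokenizer` converter is flattening the HuggingFace vocabulary (a token-to-id mapping) into a newline-separated string whose line order follows the token ids. A standalone sketch of that ordering step, using a made-up vocabulary:

from collections import OrderedDict

# hypothetical HuggingFace-style vocabulary: token -> id
hf_vocab = {"world": 3, "[PAD]": 0, "hello": 2, "[UNK]": 1}

# sort by id so line N of the serialized vocab is the token with id N
ordered_vocab = OrderedDict(sorted(hf_vocab.items(), key=lambda item: int(item[1])))
vocab_attr = "\n".join(ordered_vocab.keys())

print(vocab_attr.splitlines())  # ['[PAD]', '[UNK]', 'hello', 'world']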

View file

@@ -17,7 +17,7 @@ from ._extensions_pydll import (  # noqa
 def get_library_path():
     """
     The custom operator library binary path
-    :return: A string of the this library path.
+    :return: A string of this library path.
     """
     mod = sys.modules['onnxruntime_extensions._extensions_pydll']
     return mod.__file__

View file

@@ -11,11 +11,11 @@ import numpy as np
 from ._ocos import default_opset_domain, get_library_path  # noqa
 from ._cuops import onnx, onnx_proto, SingleOpGraph

 _ort_check_passed = False
 try:
     from packaging import version as _ver
     import onnxruntime as _ort
     if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
         _ort_check_passed = True
 except ImportError:
@@ -37,6 +37,7 @@ def get_opset_version_from_ort():
         "1.12": 17,
         "1.13": 17,
         "1.14": 18,
+        "1.15": 18
     }

     ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -59,6 +60,13 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),

 class OrtPyFunction:
+    """
+    OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
+    equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
+    standard Python function. The order of the function arguments correlates directly with
+    the sequence of the input/output in the ONNX graph.
+    """
+
     def get_ort_session_options(self):
         so = _ort.SessionOptions()
         for k, v in self.extra_session_options.items():
@@ -66,7 +74,7 @@ class OrtPyFunction:
             so.register_custom_ops_library(get_library_path())
         return so

-    def __init__(self, cpu_only=None):
+    def __init__(self, path_or_model=None, cpu_only=None):
         self._onnx_model = None
         self.ort_session = None
         self.default_inputs = {}
@@ -75,6 +83,14 @@ class OrtPyFunction:
         if _ort.get_device() == 'GPU':
             self.execution_providers = ['CUDAExecutionProvider']
         self.extra_session_options = {}
+        mpath = None
+        if isinstance(path_or_model, str):
+            oxml = onnx.load_model(path_or_model)
+            mpath = path_or_model
+        else:
+            oxml = path_or_model
+        if path_or_model is not None:
+            self._bind(oxml, mpath)

     def create_from_customop(self, op_type, *args, **kwargs):
         graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
@@ -130,17 +146,13 @@ class OrtPyFunction:
     @classmethod
     def from_customop(cls, op_type, *args, **kwargs):
-        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
+        return (cls(cpu_only=cls._get_kwarg_device(kwargs))
+                .create_from_customop(op_type, *args, **kwargs))

     @classmethod
     def from_model(cls, path_or_model, *args, **kwargs):
-        mpath = None
-        if isinstance(path_or_model, str):
-            oxml = onnx.load_model(path_or_model)
-            mpath = path_or_model
-        else:
-            oxml = path_or_model
-        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)
+        fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
+        return fn

     def _argument_map(self, *args, **kwargs):
         idx = 0
@@ -169,7 +181,7 @@ class OrtPyFunction:

 def optimize_model(model_or_file, output_file):
-    sess_options = OrtPyFunction.get_ort_session_options()
+    sess_options = OrtPyFunction().get_ort_session_options()
     sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
     sess_options.optimized_model_filepath = output_file
     _ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
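Given the docstring and the new `path_or_model` constructor argument above, an OrtPyFunction can now be built directly from a model path and called like a plain function. A minimal sketch, where the model path and input value are hypothetical placeholders:

import numpy as np
from onnxruntime_extensions import OrtPyFunction

# "decoder.onnx" is a placeholder for any model whose first graph input is an int32 tensor
fn = OrtPyFunction("decoder.onnx")   # binds the model, same effect as OrtPyFunction.from_model("decoder.onnx")
ids = np.asarray([[3, 4, 5]], dtype=np.int32)
result = fn(ids)                     # positional arguments follow the order of the graph inputs
print(result)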

View file

@@ -241,6 +241,6 @@ class WhisperDataProcGraph:
         inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
         del g.input[:]
         g.input.extend(inputs)
-        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
+        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))

         return make_onnx_model(g, opset_version=self.opset_version)

View file

@@ -50,7 +50,8 @@ class TestAudioCodec(unittest.TestCase):
     def test_decoder_resampling(self):
         test_file = util.get_test_data_file('data', 'jfk.flac')
         blob = bytearray(util.read_file(test_file, mode='rb'))
-        decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
+        decoder = PyOrtFunction.from_customop(
+            'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
         pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
         self.assertEqual(pcm_tensor.shape, (1, 176000))

View file

@@ -8,7 +8,6 @@ from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
 import onnx
 from onnx import onnx_pb as onnx_proto

 _is_torch_available = False
 try:
     import torch

View file

@@ -1,106 +1,41 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import sys
 import unittest
 import numpy as np
-import onnxruntime as _ort
-from packaging import version
-from transformers import AutoTokenizer, WhisperProcessor
-from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
+from transformers import AutoTokenizer
+from onnxruntime_extensions import OrtPyFunction, gen_processing_models


-@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
+    def test_bert_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors='np')
+
+        ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids)
+
     def test_llama_tokenizer(self):
         # replace the official model name after the model is not gated anymore
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-        ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
+        text = "I was born in 92000, and this is falsé."
+        ids = tokenizer.encode(text, return_tensors="np")

         ort_tok = OrtPyFunction.from_model(gen_processing_models(
             tokenizer,
             pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
-        actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)

     def test_t5_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
-        ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
+        text = "best hotel in bay area."
+        ids = tokenizer.encode(text, return_tensors="np")

         ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
-        actual_ids = ort_tok(["best hotel in bay area."])[0]
+        actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)

-    def test_whisper_overall(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, post_m = gen_processing_models(processor,
-                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
-                                              post_kwargs={})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
-        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
-        log_mel = fn_pre(simaudio)
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-        fn_post = OrtPyFunction.from_model(post_m)
-        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
-        self.assertEqual(rel[0], "$%&")
-
-    def test_whisper_audio_decoder(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
-
-        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
-        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
-        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
-
-        self.assertEqual(log_mel.shape, (1, 80, 3000))
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_ort_stft_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
-        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
-        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
-    def test_stft_norm_consistency(self):
-        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-        pre_m, _ = gen_processing_models(processor,
-                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
-
-        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
-        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
-        raw_audio = OrtPyFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
-
-        input_features = processor([raw_audio[0]], sampling_rate=16000)
-        expected = input_features['input_features'][0]
-
-        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
-        actual = log_mel[0]
-
-        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
-        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
-

 if __name__ == '__main__':
     unittest.main()

View file

@@ -114,5 +114,6 @@ class TestBertTokenizer(unittest.TestCase):
         print("\n*** Offset mapping tests complete. ***\n")


 if __name__ == "__main__":
     unittest.main()

View file

@@ -55,7 +55,6 @@ class TestBertTokenizerOp(unittest.TestCase):
             input="cat isnot playing toyssss"
         )

     def test_text_to_case1_with_hf_tok(self):
         ort_tok = pnp.PreHuggingFaceBert(hf_tok=_bert_cased_tokenizer)
         model = pnp.export(pnp.SequentialProcessingModule(ort_tok), ["whatever"], opset_version=12)

test/test_whisper.py (new file, 88 lines)
View file

@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import onnxruntime as _ort

from packaging import version
from transformers import WhisperProcessor
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models


@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
class TestHuggingfaceWhisper(unittest.TestCase):
    def test_whisper_overall(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        pre_m, post_m = gen_processing_models(processor,
                                              pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
                                              post_kwargs={})

        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
        simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
        log_mel = fn_pre(simaudio)
        self.assertEqual(log_mel.shape, (1, 80, 3000))

        fn_post = OrtPyFunction.from_model(post_m)
        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
        self.assertEqual(rel[0], "$%&")

    def test_whisper_audio_decoder(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        pre_m, _ = gen_processing_models(processor,
                                         pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})

        fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
        test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))

        self.assertEqual(log_mel.shape, (1, 80, 3000))

    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
    def test_ort_stft_consistency(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        pre_m, _ = gen_processing_models(processor,
                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})

        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
        raw_audio = OrtPyFunction.from_customop(
            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)

        input_features = processor([raw_audio[0]], sampling_rate=16000)
        expected = input_features['input_features'][0]

        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
        actual = log_mel[0]

        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)

    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
    def test_stft_norm_consistency(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        pre_m, _ = gen_processing_models(processor,
                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})

        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
        raw_audio = OrtPyFunction.from_customop(
            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)

        input_features = processor([raw_audio[0]], sampling_rate=16000)
        expected = input_features['input_features'][0]

        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
        actual = log_mel[0]

        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)


if __name__ == "__main__":
    unittest.main()