Add Bert tokenizer to the supported model list and related code refinement (#503)

* Add Bert tokenizer to the supported model list and refine the related code

* unit test fix
Wenbing Li 2023-08-02 14:01:36 -07:00 committed by GitHub
Parent 6209804ee9
Commit 922b7cc387
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 226 additions and 139 deletions

View File

@@ -4,24 +4,42 @@
###############################################################################
"""
The entry point to the onnxruntime-extensions package.
The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and
post-processing tasks. It also provides an API to register custom operations implemented in Python.
This enables more flexibility and control over model execution, thus expanding the functionality of ONNX Runtime.
"""
__author__ = "Microsoft"
__all__ = [
'gen_processing_models',
'get_library_path',
'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
'enable_py_op',
'expand_onnx_inputs',
'hook_model_op',
'default_opset_domain',
'OrtPyFunction', 'PyOrtFunction',
'optimize_model',
'make_onnx_model',
'ONNXRuntimeError',
'hash_64',
'__version__',
]
from ._version import __version__
from ._ocos import get_library_path # noqa
from ._ocos import Opdef, PyCustomOpDef # noqa
from ._ocos import hash_64 # noqa
from ._ocos import enable_py_op # noqa
from ._ocos import expand_onnx_inputs # noqa
from ._ocos import hook_model_op # noqa
from ._ocos import default_opset_domain # noqa
from ._cuops import * # noqa
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
from ._ocos import hash_64
from ._ocos import enable_py_op
from ._ocos import expand_onnx_inputs
from ._ocos import hook_model_op
from ._ocos import default_opset_domain
from ._cuops import * # noqa
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
from .cvt import gen_processing_models
# alias the implementations under more formal names
onnx_op = Opdef.declare
PyOp = PyCustomOpDef
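
As a quick illustration of the two API surfaces described in the docstring above, here is a minimal sketch (the tokenizer checkpoint and input text are illustrative assumptions, mirroring this PR's test suite):

    # sketch: generate a pre-processing model from a HF tokenizer and run it
    from transformers import AutoTokenizer
    from onnxruntime_extensions import gen_processing_models, OrtPyFunction

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # assumed checkpoint
    pre_model = gen_processing_models(tokenizer, pre_kwargs={})[0]
    encode = OrtPyFunction(pre_model)       # call the ONNX graph like a Python function
    input_ids = encode(["hello world"])[0]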

View File

@@ -253,7 +253,9 @@ class BertTokenizer(CustomOp):
def serialize_attr(cls, attrs):
attrs_data = {}
for k_, v_ in attrs.items():
if k_ == 'vocab_file':
if k_ == 'vocab':
attrs_data['vocab_file'] = v_
elif k_ == 'vocab_file':
with open(v_, "r", encoding='utf-8') as model_file:
lines = model_file.readlines()
attrs_data[k_] = '\n'.join(lines)
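
Both branches above end up storing the vocabulary text under the 'vocab_file' attribute. A hedged sketch of the two call paths, assuming serialize_attr is exposed as a classmethod as its cls parameter suggests (the vocabulary content and file path are illustrative):

    # in-memory vocabulary string, e.g. as produced by HFTokenizerConverter.bert_tokenizer
    attrs = BertTokenizer.serialize_attr({'vocab': '[PAD]\n[UNK]\nhello'})
    # vocabulary file on disk; its lines are read and stored under the same attribute
    attrs = BertTokenizer.serialize_attr({'vocab_file': '/path/to/vocab.txt'})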

View File

@@ -2,6 +2,8 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
from typing import Callable
class PyCustomOpDef:
undefined: int = ...
@@ -21,6 +23,8 @@ class PyCustomOpDef:
dt_complex64: int = ...
dt_complex128: int = ...
dt_bfloat16: int = ...
def install_hooker(self, invocation_handler: Callable) -> None:
...
...
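
The dt_* constants in the stub above are the type tags used when declaring a custom operator in Python; a minimal sketch (the operator name and body are illustrative):

    from onnxruntime_extensions import onnx_op, PyCustomOpDef

    # declare a trivial string-passthrough custom op using the stub's type tags
    @onnx_op(op_type="StringEcho",
             inputs=[PyCustomOpDef.dt_string],
             outputs=[PyCustomOpDef.dt_string])
    def string_echo(strings):
        return strings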

View File

@@ -9,9 +9,9 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
import json
import onnx
import numpy as np
from numpy import array as nparray
from functools import partial
from collections import namedtuple
from collections import namedtuple, OrderedDict
from ._cuops import CustomOpConverter, SingleOpGraph
from .util import read_file
@@ -32,6 +32,25 @@ class HFTokenizerConverter(CustomOpConverter):
attrs.update(**kwargs)
return attrs
def bert_tokenizer(self, **kwargs):
hf_bert_tokenizer = self.tokenizer
# has to be sorted since the token ids were generated automatically.
ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
vocab = '\n'.join(ordered_vocab.keys())
attrs = dict(vocab=vocab)
init_kwargs = hf_bert_tokenizer.init_kwargs
attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0
attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0
attrs.update(**kwargs)
return attrs
def bert_decoder(self, **kwargs):
hf_bert_tokenizer = self.tokenizer
attrs = {'vocab': json.dumps(
hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))}
attrs.update(**kwargs)
return attrs
def bpe_decoder(self, **kwargs):
decoder = self.tokenizer.decoder
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
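
A sketch of how the bert_tokenizer converter above is exercised (the checkpoint name is an assumption, and the converter is assumed to be constructed with the HF tokenizer, as elsewhere in this module):

    from transformers import AutoTokenizer

    hf_tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    attrs = HFTokenizerConverter(hf_tok).bert_tokenizer()
    # the vocab attribute lists tokens in id order, so the id-0 token comes first
    first_token = attrs['vocab'].splitlines()[0]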
@@ -95,22 +114,28 @@ TokenOpParam = namedtuple("TokenOpParam",
"default_inputs"],
defaults=(None, None, None, None, None))
# fmt: off
# @formatter:off
_PROCESSOR_DICT = {
"GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder),
"ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder),
"RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
None, None),
"T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
"SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
"BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
"DistilBertTokenizer":
TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
"GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
"ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
"RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
None, None, None),
"T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
"SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
default_inputs={'add_eos': [True]}),
"LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
"SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
"LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
"SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
default_inputs={'add_bos': [True]}),
}
# fmt: on
# @formatter:on
class HFTokenizerOnnxGraph:
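
The table above keys each HuggingFace tokenizer class name to a TokenOpParam describing its pre-processing op, post-processing op, and optional default inputs; an illustrative lookup:

    # e.g. a DistilBERT checkpoint reuses the BertTokenizer/BertDecoder custom ops
    param = _PROCESSOR_DICT["DistilBertTokenizer"]
    # -> pre op 'BertTokenizer', post op 'BertDecoder', no default inputs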
@@ -137,31 +162,34 @@ class HFTokenizerOnnxGraph:
_cvt_func = self.cvt_quadruple.pre_attribute_cvt
cvt = partial(_cvt_func, self.cvt_obj)
g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
default_inputs = []
if with_default_inputs:
op_class = SingleOpGraph.get_op_class(_cvt_op)
default_inputs = op_class.input_default_values()
if default_inputs is None:
raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
n_inputs = len(default_inputs)
if self.cvt_quadruple.default_inputs is not None:
default_inputs.update(self.cvt_quadruple.default_inputs)
if len(default_inputs) != n_inputs:
raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
return g
new_initializers = []
# add default_inputs into initializers to simplify the model input
n_inputs = len(default_inputs)
if self.cvt_quadruple.default_inputs is not None:
default_inputs.update(self.cvt_quadruple.default_inputs)
if len(default_inputs) != n_inputs:
raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
for k, v in default_inputs.items():
input_value_info = next((i for i in g.input if i.name == k), None)
if input_value_info is None:
raise ValueError("The input {} is not found in the graph".format(k))
new_initializers = []
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
value = np.array(v, np_dtype)
new_initializers.append(onnx.numpy_helper.from_array(value, k))
g.initializer.extend(new_initializers)
new_inputs = [i for i in g.input if i.name not in default_inputs]
g.ClearField("input")
g.input.extend(new_inputs)
for k, v in default_inputs.items():
input_value_info = next((i for i in g.input if i.name == k), None)
if input_value_info is None:
raise ValueError("The input {} is not found in the graph".format(k))
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
value = nparray(v, np_dtype)
new_initializers.append(onnx.numpy_helper.from_array(value, k))
g.initializer.extend(new_initializers)
new_inputs = [i for i in g.input if i.name not in default_inputs]
g.ClearField("input")
g.input.extend(new_inputs)
return g
def post_processing(self, **kwargs):
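
The net effect of the initializer folding in pre_processing above: any default input supplied by the TokenOpParam is baked into the graph as an initializer and dropped from the graph inputs, so callers only pass the text. A hedged sketch (the checkpoint name comes from this PR's tests):

    from transformers import AutoTokenizer
    from onnxruntime_extensions import gen_processing_models

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    pre_m = gen_processing_models(tok, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]
    # 'add_bos' is now an initializer rather than a graph input
    assert all(i.name != 'add_bos' for i in pre_m.graph.input)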

View File

@@ -17,7 +17,7 @@ from ._extensions_pydll import ( # noqa
def get_library_path():
"""
The custom operator library binary path
:return: A string of the this library path.
:return: A string of this library path.
"""
mod = sys.modules['onnxruntime_extensions._extensions_pydll']
return mod.__file__
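
Typical use of get_library_path() is to register the extensions operators with an ONNX Runtime session, as code elsewhere in this PR does (the model file name below is illustrative):

    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    # sess = ort.InferenceSession("model_with_custom_ops.onnx", so)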

View File

@@ -11,11 +11,11 @@ import numpy as np
from ._ocos import default_opset_domain, get_library_path # noqa
from ._cuops import onnx, onnx_proto, SingleOpGraph
_ort_check_passed = False
try:
from packaging import version as _ver
import onnxruntime as _ort
if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
_ort_check_passed = True
except ImportError:
@@ -37,6 +37,7 @@ def get_opset_version_from_ort():
"1.12": 17,
"1.13": 17,
"1.14": 18,
"1.15": 18
}
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -59,6 +60,13 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
class OrtPyFunction:
"""
OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
standard Python function. The order of the function arguments matches the order of the
inputs/outputs in the ONNX graph.
"""
def get_ort_session_options(self):
so = _ort.SessionOptions()
for k, v in self.extra_session_options.items():
@@ -66,7 +74,7 @@ class OrtPyFunction:
so.register_custom_ops_library(get_library_path())
return so
def __init__(self, cpu_only=None):
def __init__(self, path_or_model=None, cpu_only=None):
self._onnx_model = None
self.ort_session = None
self.default_inputs = {}
@@ -75,6 +83,14 @@ class OrtPyFunction:
if _ort.get_device() == 'GPU':
self.execution_providers = ['CUDAExecutionProvider']
self.extra_session_options = {}
mpath = None
if isinstance(path_or_model, str):
oxml = onnx.load_model(path_or_model)
mpath = path_or_model
else:
oxml = path_or_model
if path_or_model is not None:
self._bind(oxml, mpath)
def create_from_customop(self, op_type, *args, **kwargs):
graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
@@ -130,17 +146,13 @@ class OrtPyFunction:
@classmethod
def from_customop(cls, op_type, *args, **kwargs):
return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
return (cls(cpu_only=cls._get_kwarg_device(kwargs))
.create_from_customop(op_type, *args, **kwargs))
@classmethod
def from_model(cls, path_or_model, *args, **kwargs):
mpath = None
if isinstance(path_or_model, str):
oxml = onnx.load_model(path_or_model)
mpath = path_or_model
else:
oxml = path_or_model
return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)
fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
return fn
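
With the constructor now accepting path_or_model directly, the two construction styles are equivalent; a brief sketch (the file name is illustrative):

    fn = OrtPyFunction("tokenizer.onnx")      # load from a path...
    fn = OrtPyFunction.from_model(pre_m)      # ...or wrap an in-memory ModelProto
    ids = fn(["hello world"])                 # positional args follow the graph inputs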
def _argument_map(self, *args, **kwargs):
idx = 0
@@ -169,7 +181,7 @@ class OrtPyFunction:
def optimize_model(model_or_file, output_file):
sess_options = OrtPyFunction.get_ort_session_options()
sess_options = OrtPyFunction().get_ort_session_options()
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = output_file
_ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
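
optimize_model, shown above, writes a basic-level-optimized copy of the model with the extensions ops registered; a usage sketch (file names are illustrative):

    from onnxruntime_extensions import optimize_model

    # writes the ORT_ENABLE_BASIC-optimized model to the given output path
    optimize_model("model.onnx", "model.optimized.onnx")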

View File

@@ -241,6 +241,6 @@ class WhisperDataProcGraph:
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
return make_onnx_model(g, opset_version=self.opset_version)

View File

@@ -50,7 +50,8 @@ class TestAudioCodec(unittest.TestCase):
def test_decoder_resampling(self):
test_file = util.get_test_data_file('data', 'jfk.flac')
blob = bytearray(util.read_file(test_file, mode='rb'))
decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
decoder = PyOrtFunction.from_customop(
'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
self.assertEqual(pcm_tensor.shape, (1, 176000))

View File

@@ -8,7 +8,6 @@ from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
import onnx
from onnx import onnx_pb as onnx_proto
_is_torch_available = False
try:
import torch

View File

@@ -1,106 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import onnxruntime as _ort
from packaging import version
from transformers import AutoTokenizer, WhisperProcessor
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
from transformers import AutoTokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
class TestAutoTokenizer(unittest.TestCase):
def test_bert_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='np')
ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids)
def test_llama_tokenizer(self):
# switch to the official model name once the model is no longer gated
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
text = "I was born in 92000, and this is falsé."
ids = tokenizer.encode(text, return_tensors="np")
ort_tok = OrtPyFunction.from_model(gen_processing_models(
tokenizer,
pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids[0], actual_ids)
def test_t5_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
text = "best hotel in bay area."
ids = tokenizer.encode(text, return_tensors="np")
ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
actual_ids = ort_tok(["best hotel in bay area."])[0]
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids[0], actual_ids)
def test_whisper_overall(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, post_m = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
post_kwargs={})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
log_mel = fn_pre(simaudio)
self.assertEqual(log_mel.shape, (1, 80, 3000))
fn_post = OrtPyFunction.from_model(post_m)
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
self.assertEqual(rel[0], "$%&")
def test_whisper_audio_decoder(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
self.assertEqual(log_mel.shape, (1, 80, 3000))
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_ort_stft_consistency(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
# ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
self.assertTrue(num_mismatched / np.size(expected) < 0.02)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_stft_norm_consistency(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
if __name__ == '__main__':
unittest.main()

View File

@@ -114,5 +114,6 @@ class TestBertTokenizer(unittest.TestCase):
print("\n*** Offset mapping tests complete. ***\n")
if __name__ == "__main__":
unittest.main()

View File

@@ -55,7 +55,6 @@ class TestBertTokenizerOp(unittest.TestCase):
input="cat isnot playing toyssss"
)
def test_text_to_case1_with_hf_tok(self):
ort_tok = pnp.PreHuggingFaceBert(hf_tok=_bert_cased_tokenizer)
model = pnp.export(pnp.SequentialProcessingModule(ort_tok), ["whatever"], opset_version=12)

test/test_whisper.py (new file, 88 lines)
View File

@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import onnxruntime as _ort
from packaging import version
from transformers import WhisperProcessor
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
class TestHuggingfaceWhisper(unittest.TestCase):
def test_whisper_overall(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, post_m = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
post_kwargs={})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
log_mel = fn_pre(simaudio)
self.assertEqual(log_mel.shape, (1, 80, 3000))
fn_post = OrtPyFunction.from_model(post_m)
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
self.assertEqual(rel[0], "$%&")
def test_whisper_audio_decoder(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
self.assertEqual(log_mel.shape, (1, 80, 3000))
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_ort_stft_consistency(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
# ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
self.assertTrue(num_mismatched / np.size(expected) < 0.02)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_stft_norm_consistency(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
if __name__ == "__main__":
unittest.main()