Update whisper model test cases and e2e example (#496)

* Update whisper model test cases and e2e example

* fix unit test on windows

* more refinement

* utest fix
Wenbing Li 2023-07-21 15:27:02 -07:00 committed by GitHub
Parent 06d5a8d781
Commit 62d8598b6b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 209 additions and 293 deletions

View file

@@ -12,9 +12,6 @@ This java and Android API and packaging principles were inspired by the https://
1. install visual studio 2022 (with cmake, git, desktop C++)
2. OpenJDK: https://docs.microsoft.com/en-us/java/openjdk/download
(OpenJDK 11.0.15 LTS)
3. Gradle: https://gradle.org/releases/
(v6.9.2)
### Build command
./build.sh **-DOCOS_BUILD_JAVA=ON**

View file

@@ -66,7 +66,7 @@ def _mel_filterbank(
# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank
@@ -109,9 +109,9 @@ class WhisperPrePipeline(torch.nn.Module):
log_spec = torch.maximum(log_spec, spec_min)
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1], (
_WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH -
spec_shape[2]), dtype=torch.float)
spec_shape[1],
_WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
log_spec = (log_spec + 4.0) / 4.0
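A quick sanity check on the padding target above (not part of the diff, just arithmetic with the Whisper constants used elsewhere in this commit): N_SAMPLES // HOP_LENGTH works out to 3000 frames, which matches the (1, 80, 3000) log-mel shape the updated tests assert.

N_SAMPLES = 30 * 16000          # CHUNK_LENGTH * SAMPLE_RATE = 480000 samples per chunk
HOP_LENGTH = 160
print(N_SAMPLES // HOP_LENGTH)  # 3000 frames per 30-second chunk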
@@ -138,7 +138,7 @@ def _to_onnx_stft(onnx_model):
value=numpy_helper.from_array(np.array([0,
_WhisperHParams.N_FFT // 2, 0,
_WhisperHParams.N_FFT // 2], dtype='int64'),
name='const_14')),
name='const_14')),
make_node('Pad',
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
outputs=['pad_1_output_0'], mode='reflect'),
@@ -225,9 +225,22 @@ class WhisperDataProcGraph:
return pre_full
def post_processing(self, **kwargs):
skip_special_tokens = kwargs.get('skip_special_tokens', True)
g = SingleOpGraph.build_graph(
"BpeDecoder",
cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
skip_special_tokens=True,
cpu_only=True)
return make_onnx_model(g)
skip_special_tokens=skip_special_tokens)
bpenode = g.node[0]
bpenode.input[0] = "generated_ids"
nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
bpenode]
del g.node[:]
g.node.extend(nodes)
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
return make_onnx_model(g, opset_version=self.opset_version)
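For context, a minimal usage sketch of the rewired post-processing graph (mirroring the updated unit test further down, and assuming `processor` is a transformers WhisperProcessor): the beam-search output `sequences` arrives as int32 with shape ['N', 'seq_len', 'ids'], is cast to int64 `generated_ids`, and is then decoded by the BpeDecoder node into strings.

import numpy as np
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

_, post_m = gen_processing_models(
    processor,
    pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
    post_kwargs={})
fn_post = OrtPyFunction.from_model(post_m)

# int32 token ids shaped [batch, seq, ids]; the Cast node converts them to int64.
ids = np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :]
print(fn_post(ids))  # decodes to "$%&" for these ids with the whisper-tiny.en tokenizer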

View file

@@ -60,7 +60,7 @@ def mel_filterbank(
# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank
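As a reminder of how this filterbank is consumed downstream (a sketch based on the WhisperPrePipeline code in this commit; the power-spectrogram input here is a placeholder): the (n_mels, n_fft//2 + 1) matrix is multiplied against the STFT power frames to produce the mel spectrogram before the log/clamp steps.

import numpy as np
from onnxruntime_extensions import util

fbank = util.mel_filterbank(sr=16000, n_fft=400, n_mels=80)  # shape (80, 201)
power_spec = np.random.rand(201, 3000).astype(np.float32)    # placeholder STFT power frames
mel_spec = fbank @ power_spec                                # (80, 3000)
log_spec = np.log10(np.maximum(mel_spec, 1e-10))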
@@ -126,3 +126,55 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
elif attr.type == onnx.AttributeProto.GRAPHS:
for subgraph in attr.graphs:
remove_unused_initializers(subgraph, top_level_initializers)
def quick_merge(*models, connection_indices=None):
"""
This function merges multiple ONNX models into a single model, without performing any ONNX format checks.
Parameters:
*models (onnx.ModelProto): Varargs parameter representing the ONNX models to be merged.
connection_indices (List[List[int]], optional): A nested list specifying which outputs in one model should connect
to which inputs in the next model, based on their indices.
If not provided, it's assumed that the sequence of outputs in
one model exactly matches the sequence of inputs in the next model.
Returns:
merged_model (onnx.ModelProto): The merged ONNX model.
Raises:
ValueError: If there is any conflict in tensor names, either in initializers or in nodes, including subgraphs.
If there is any conflict in opset versions for the same domain.
"""
merged_graph = models[0].graph
# Dictionary to store unique opsets
opset_imports = {opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}
# Iterate over all other models and merge
for model_idx, model in enumerate(models[1:], start=1):
if connection_indices is None:
io_map = [(out.name, in_.name) for out, in_ in zip(models[model_idx - 1].graph.output, model.graph.input)]
else:
io_map = [(models[model_idx - 1].graph.output[out_idx].name, model.graph.input[in_idx].name)
for out_idx, in_idx in connection_indices[model_idx - 1]]
merged_graph = onnx.compose.merge_graphs(merged_graph, model.graph, io_map)
for opset in model.opset_import:
if not opset.domain:
opset.domain = "ai.onnx"
if opset.domain in opset_imports and opset_imports[opset.domain].version != opset.version:
raise ValueError(f"Conflict in opset versions for domain '{opset.domain}': " +
f"model {model_idx} has version {opset.version}, while previous model has version " +
f"{opset_imports[opset.domain].version}.")
else:
opset_imports[opset.domain] = opset
default_opset = opset_imports.pop("ai.onnx", None)
merged_model = onnx.helper.make_model_gen_version(merged_graph,
opset_imports=[default_opset],
producer_name='ONNX Model Merger')
merged_model.opset_import.extend(opset_imports.values())
return merged_model
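A hedged usage sketch (file names are placeholders) of how quick_merge is intended to chain the Whisper data-processing graphs around the core beam-search model, mirroring the updated end-to-end example below:

import onnx
from onnxruntime_extensions import util

pre_m = onnx.load_model("whisper_pre.onnx")                  # output: log_mel
core_m = onnx.load_model("whisper-tiny.en_beamsearch.onnx")  # first input: input_features
post_m = onnx.load_model("whisper_post.onnx")                # input: sequences

# With connection_indices omitted, each model's outputs are paired with the next
# model's inputs in declaration order (log_mel -> input_features, sequences -> sequences).
full_m = util.quick_merge(pre_m, core_m, post_m)
onnx.save_model(full_m, "whisper_e2e.onnx")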

View file

@@ -8,7 +8,9 @@
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStruct("AudioDecoder", AudioDecoder)
#endif
);

View file

@@ -3,7 +3,7 @@
import wave
import unittest
import numpy as np
from onnxruntime_extensions import PyOrtFunction, util, make_onnx_model
from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
import onnx
from onnx import onnx_pb as onnx_proto
@@ -12,6 +12,7 @@ from onnx import onnx_pb as onnx_proto
_is_torch_available = False
try:
import torch
_is_torch_available = True
except ImportError:
pass
@@ -19,6 +20,7 @@ except ImportError:
_is_librosa_avaliable = False
try:
import librosa
_is_librosa_avaliable = True
except ImportError:
pass
@@ -57,7 +59,7 @@ class TestAudio(unittest.TestCase):
audio = wavefile.readframes(samples)
audio_as_np_int16 = np.frombuffer(audio, dtype=np.int16)
audio_as_np_float32 = audio_as_np_int16.astype(np.float32)
max_int16 = 2**15
max_int16 = 2 ** 15
cls.test_pcm = audio_as_np_float32 / max_int16
@staticmethod
@@ -99,7 +101,7 @@ class TestAudio(unittest.TestCase):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
actual = actual[0]
actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
@@ -109,7 +111,7 @@ class TestAudio(unittest.TestCase):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
ortx_stft = OrtPyFunction.from_customop("StftNorm", cpu_only=True)
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
actual = actual[0]
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
@@ -125,7 +127,7 @@ class TestAudio(unittest.TestCase):
center=True,
return_complex=True).abs().pow(2).numpy()
audio_pcm = np.expand_dims(self.test_pcm, axis=0)
ortx_stft = PyOrtFunction.from_customop("StftNorm")
ortx_stft = OrtPyFunction.from_customop("StftNorm")
actual = ortx_stft(audio_pcm, 400, 160, np.hanning(wlen).astype(np.float32), 400)
actual = actual[0]
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)

View file

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import transformers as _hfts
@@ -9,6 +10,7 @@ from packaging import version
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
class TestAutoTokenizer(unittest.TestCase):
def test_t5_tokenizer(self):
tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
@@ -31,37 +33,77 @@ class TestAutoTokenizer(unittest.TestCase):
actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
np.testing.assert_array_equal(ids[0][:-1], actual_ids)
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
def test_whisper(self):
def test_whisper_overall(self):
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, post_m = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
post_kwargs={})
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
post_kwargs={})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
t = np.linspace(0, 2*np.pi, 480000).astype(np.float32)
t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
log_mel = fn_pre(simaudio)
self.assertEqual(log_mel.shape, (1, 80, 3000))
fn_post = OrtPyFunction.from_model(post_m)
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32))
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
self.assertEqual(rel[0], "$%&")
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
def test_whisper_audio_decoder(self):
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
raw_audio = np.fromfile(test_flac_file, dtype=np.uint8)
log_mel = fn_pre(np.expand_dims(raw_audio, axis=0))
audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
self.assertEqual(log_mel.shape, (1, 80, 3000))
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_ort_stft_consistency(self):
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
# ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
self.assertTrue(num_mismatched / np.size(expected) < 0.02)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
def test_stft_norm_consistency(self):
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pre_m, _ = gen_processing_models(processor,
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
raw_audio = OrtPyFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
input_features = processor([raw_audio[0]], sampling_rate=16000)
expected = input_features['input_features'][0]
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
actual = log_mel[0]
np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
if __name__ == '__main__':
unittest.main()

View file

@@ -1,291 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import io
# Run the Whisper end-to-end inference with ONNXRuntime-Extensions for pre-/post-processing.
# THIS SCRIPT IS FOR DEMONSTRATION ONLY AND IS NOT PART OF THE PACKAGE.
# TO GENERATE THE FULLY FUNCTIONAL MODEL, PLEASE USE https://github.com/microsoft/Olive
import os
import onnx
import re
import torch
import subprocess
import numpy as np
import onnxruntime as ort
from onnx import numpy_helper
from packaging import version
from transformers import WhisperProcessor
from onnxruntime_extensions import OrtPyFunction, util
from onnxruntime_extensions.cvt import gen_processing_models
from onnxruntime_extensions import PyOrtFunction, util
from onnxruntime_extensions.cvt import HFTokenizerConverter
# Constants
MODEL_NAME = "openai/whisper-tiny.en"
CACHE_DIR = 'temp_caches_onnx'
OUTPUT_DIR = 'temp_model_onnx'
FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"
TEST_AUDIO_FILE = util.get_test_data_file('../test/data', "1272-141231-0002.mp3")
# the flags for pre-processing
USE_ONNX_STFT = True
USE_AUDIO_DECODER = True
def check_onnx_version():
if version.parse(ort.__version__) < version.parse("1.16.0"):
raise RuntimeError("ONNXRuntime version must >= 1.16.0")
if not USE_AUDIO_DECODER:
try:
import librosa
except ImportError:
raise ImportError("Please pip3 install librosa without ort-extensions audio codec support.")
def export_onnx_model():
print("Exporting Whisper ONNX model from Huggingface model hub...")
command = ['python', '-m',
'onnxruntime.transformers.models.whisper.convert_to_onnx',
'-m', MODEL_NAME,
'--cache_dir', CACHE_DIR,
'--output', OUTPUT_DIR,
'--precision', 'fp32']
process = subprocess.run(command)
if process.returncode != 0:
raise RuntimeError("Failed to export the core ONNX models.")
# hard-coded audio hyperparameters
# copied from https://github.com/openai/whisper/blob/main/whisper/audio.py#L12
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH
def process_test_file():
if not os.path.exists(TEST_AUDIO_FILE):
raise FileNotFoundError(f"Test audio path {TEST_AUDIO_FILE} does not exist.")
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
_processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
pre_m, post_m = gen_processing_models(_processor,
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True},
post_kwargs={},
opset=17)
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
class CustomOpStftNorm(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
def main():
check_onnx_version()
export_onnx_model()
log_mel, pre_m, post_m = process_test_file()
@staticmethod
def forward(ctx, audio, n_fft, hop_length, window):
win_length = window.shape[0]
stft = torch.stft(audio, n_fft, hop_length, win_length, window,
center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
return stft.abs() ** 2
class WhisperPrePipeline(torch.nn.Module):
def __init__(self):
super().__init__()
self.window = torch.hann_window(N_FFT)
self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
def forward(self, audio_pcm: torch.Tensor):
stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
magnitudes = stft_norm[:, :, :-1]
mel_spec = self.mel_filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
spec_min = log_spec.max() - 8.0
log_spec = torch.maximum(log_spec, spec_min)
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
def _to_onnx_stft(onnx_model):
"""Convert custom-op STFT-Norm to ONNX STFT"""
node_idx = 0
new_stft_nodes = []
stft_norm_node = None
for node in onnx_model.graph.node:
if node.op_type == "StftNorm":
stft_norm_node = node
break
node_idx += 1
if stft_norm_node is None:
raise RuntimeError("Cannot find STFTNorm node in the graph")
make_node = onnx.helper.make_node
replaced_nodes = [
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
name='const_14')),
make_node('Pad',
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
outputs=['pad_1_output_0'], mode='reflect'),
make_node('STFT',
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
perm=[0, 2, 1, 3]),
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
del onnx_model.graph.node[:]
onnx_model.graph.node.extend(new_stft_nodes)
onnx.checker.check_model(onnx_model)
return onnx_model
def _torch_export(*arg, **kwargs):
with io.BytesIO() as f:
torch.onnx.export(*arg, f, **kwargs)
return onnx.load_from_string(f.getvalue())
def preprocessing(audio_data):
if USE_AUDIO_DECODER:
decoder = PyOrtFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
audio_pcm = torch.from_numpy(decoder(audio_data))
else:
audio_pcm = torch.from_numpy(audio_data)
prep_model_name = 'whisper_pre.onnx'
whisper_processing = WhisperPrePipeline()
model_args = (audio_pcm,)
pre_model = _torch_export(
whisper_processing,
model_args,
input_names=["audio_pcm"],
output_names=["log_mel"],
do_constant_folding=True,
export_params=True,
opset_version=17,
dynamic_axes={
"audio_pcm": {1: "sample_len"},
}
)
onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
if USE_ONNX_STFT:
pre_model = _to_onnx_stft(pre_model)
util.remove_unused_initializers(pre_model.graph)
pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
if not USE_AUDIO_DECODER:
return pre_f(audio_data)
else:
pre_full = onnx.compose.merge_models(
decoder.onnx_model,
pre_model,
io_map=[("floatPCM", "audio_pcm")])
pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
onnx.save_model(pre_f.onnx_model, os.path.join(root_dir, "whisper_codec_pre.onnx"))
result = pre_f(audio_data)
return result
def merge_models(core: str, output_model: str, audio_data):
m_pre_path = os.path.join(root_dir, "whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
m_pre = onnx.load_model(m_pre_path)
m_core = onnx.load_model(core)
m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
m2 = onnx.load_model(os.path.join(root_dir, "whisper_post.onnx"))
m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
bpe_decoder_node = m_all.graph.node.pop(-1)
make_node = onnx.helper.make_node
bpe_decoder_node.input.pop(0)
bpe_decoder_node.input.extend(["generated_ids"])
m_all.graph.node.extend([
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
bpe_decoder_node
])
try:
onnx.save_model(m_all, output_model)
except ValueError:
onnx.save_model(m_all, output_model,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"{os.path.basename(output_model)}.data",
convert_attribute=True)
print(f"The final merged model was saved as: {output_model}")
print("Verify the final model...")
m_final = PyOrtFunction.from_model(output_model, cpu_only=True)
output_text = m_final(audio_data,
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
print(output_text)
def postprocessing(token_ids, hf_processor):
fn_decoder = PyOrtFunction.from_customop(
"BpeDecoder",
cvt=HFTokenizerConverter(hf_processor.tokenizer).bpe_decoder,
skip_special_tokens=True,
cpu_only=True)
onnx.save_model(fn_decoder.onnx_model, os.path.join(root_dir, "whisper_post.onnx"))
return fn_decoder(token_ids)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--audio", required=True, help="Path to audio file")
parser.add_argument("-m", "--model", required=True, help="Path to custom export of Whisper with beam search")
args = parser.parse_args()
print("Looking for the exported model...", end='')
onnx_model_name = os.path.basename(args.model)
if not re.search("whisper-.*_beamsearch\.onnx", onnx_model_name):
print("None")
print("Cannot find the whisper beamsearch ONNX models. "
"Please run this script from where Whisper ONNX model was exported. like */onnx_models/openai")
exit(-1)
else:
print(f"{onnx_model_name}")
model_name = "openai/" + onnx_model_name[:-len("_beamsearch.onnx")]
root_dir = os.path.dirname(args.model)
_processor = WhisperProcessor.from_pretrained(model_name)
# The model similar to Huggingface model like:
# model = WhisperForConditionalGeneration.from_pretrained(model_name)
# The onnx model can be generated by the following command:
# python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
# !!! only be valid after onnxruntime 1.15 or nightly build after 05/05/2023
model = PyOrtFunction.from_model(args.model, cpu_only=True)
test_file = util.get_test_data_file(args.audio)
if USE_AUDIO_DECODER:
with open(test_file, "rb") as _f:
audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
else:
audio_blob, _ = librosa.load(test_file)
audio_blob = np.expand_dims(audio_blob, axis=0) # add a batch_size dimension
log_mel = preprocessing(audio_blob)
print(log_mel.shape)
input_features = log_mel
# similar to:
# generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
ort_outputs = model(input_features,
# Apply core ONNX model
fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
token_seq = fn_core(log_mel,
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32),
np.asarray([1.0], dtype=np.float32),
np.zeros(input_features.shape).astype(np.int32))
generated_ids = ort_outputs[0]
np.asarray([1.0], dtype=np.float32))
print(token_seq.shape)
text = postprocessing(generated_ids[0], _processor)
# Apply post processing
fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
output_text = fn_post(token_seq)
print(output_text)
# Merge models and save final model
print("Combine the data processing graphs into the ONNX model...")
final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
onnx.save(final_m, FINAL_MODEL)
# Test the final model
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
text = OrtPyFunction.from_model(final_m, cpu_only=True)(
np.expand_dims(raw_audio, axis=0),
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32),
np.asarray([1.0], dtype=np.float32))
print(text)
print("build the final model...")
merge_models(args.model, args.model.replace("beamsearch", "all"), audio_blob)
if __name__ == "__main__":
main()