# onnxruntime-extensions/tutorials/whisper_e2e.py

# 136 lines
# 5.2 KiB
# Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Run Whisper end-to-end inference, using ONNXRuntime-Extensions for pre- and post-processing.
# THIS SCRIPT IS FOR DEMO PURPOSES ONLY AND IS NOT PART OF THE PACKAGE.
# TO GENERATE THE FULL-FUNCTION MODEL, PLEASE USE https://github.com/microsoft/Olive
import os
import subprocess
import sys

import numpy as np
import onnx
import onnxruntime as ort
from packaging import version
from transformers import WhisperProcessor
from onnxruntime_extensions import OrtPyFunction, util
from onnxruntime_extensions.cvt import gen_processing_models
# Constants
MODEL_NAME = "openai/whisper-tiny.en"  # Hugging Face model id that gets exported and processed
CACHE_DIR = "temp_caches_onnx"  # cache_dir for Hugging Face downloads and the exporter
OUTPUT_DIR = "temp_model_onnx"  # --output folder of the core ONNX export
FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"  # merged pre + core + post model file
# Sample clip located through onnxruntime_extensions.util.get_test_data_file.
TEST_AUDIO_FILE = util.get_test_data_file("../test/data", "1272-141231-0002.mp3")
def check_onnx_version(min_version="1.16.0"):
    """Ensure the installed ONNX Runtime is at least *min_version*.

    Args:
        min_version: Lowest acceptable ``onnxruntime`` version string.

    Raises:
        RuntimeError: If ``onnxruntime.__version__`` is older than *min_version*.
    """
    installed = ort.__version__
    if version.parse(installed) < version.parse(min_version):
        # Report both the requirement and what was found, so the fix is obvious.
        raise RuntimeError(
            f"ONNXRuntime version must be >= {min_version}, but found {installed}.")
def export_onnx_model():
    """Export the Whisper model to ONNX with ONNX Runtime's converter script.

    Runs ``onnxruntime.transformers.models.whisper.convert_to_onnx`` in a child
    process for ``MODEL_NAME``, using ``CACHE_DIR`` as the download cache and
    writing fp32 models into ``OUTPUT_DIR``.

    Raises:
        RuntimeError: If the converter process exits with a non-zero status.
    """
    print("Exporting Whisper ONNX model from Huggingface model hub...")
    # Use the interpreter that is running this script rather than whatever
    # "python" resolves to on PATH. That may be a different environment, or may
    # not exist at all on systems that only ship "python3".
    command = [sys.executable, "-m",
               "onnxruntime.transformers.models.whisper.convert_to_onnx",
               "-m", MODEL_NAME,
               "--cache_dir", CACHE_DIR,
               "--output", OUTPUT_DIR,
               "--precision", "fp32"]
    process = subprocess.run(command)
    if process.returncode != 0:
        raise RuntimeError("Failed to export the core ONNX models.")
def process_test_file():
    """Generate the pre/post-processing models and run pre-processing on the test clip.

    Returns:
        ``(features, pre_m, post_m)``, where:
        - ``features`` is the output of the generated pre-processing model for
          ``TEST_AUDIO_FILE`` (fed in as a batch of one);
        - ``pre_m`` and ``post_m`` are the ONNX models produced by
          ``gen_processing_models``.

    Raises:
        FileNotFoundError: If ``TEST_AUDIO_FILE`` does not exist.
    """
    if not os.path.exists(TEST_AUDIO_FILE):
        raise FileNotFoundError(f"Test audio path {TEST_AUDIO_FILE} does not exist.")

    # The encoded (compressed) file is passed as raw bytes; decoding happens
    # inside the pre-processing graph (USE_AUDIO_DECODER).
    encoded_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)

    processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    pre_options = dict(USE_AUDIO_DECODER=True, USE_ONNX_STFT=True)
    pre_model, post_model = gen_processing_models(
        processor, pre_kwargs=pre_options, post_kwargs={}, opset=17)

    # Graph optimization level 0 turns ORT graph optimizations off for this session.
    preprocess = OrtPyFunction.from_model(
        pre_model, session_options={"graph_optimization_level": 0})

    batched_audio = encoded_audio[np.newaxis, :]  # add the leading batch axis
    return preprocess(batched_audio), pre_model, post_model
def get_model_inputs(ort_session, audio_data, *, model_name=None):
    """Build the positional inputs for a Whisper beam-search ONNX model.

    The returned list follows the order of ``ort_session.get_inputs()``, so it
    can be splatted directly into the model call: ``fn(*get_model_inputs(...))``.

    Args:
        ort_session: An ONNX Runtime session, or any object whose
            ``get_inputs()`` yields entries with a ``name`` attribute.
        audio_data: Value for the model's "audio_stream"/"input_features" input.
            Its leading dimension is used as the batch size.
        model_name: Hugging Face model id used to choose the vocabulary size and
            the decoder start token. Defaults to the module-level ``MODEL_NAME``.

    Returns:
        A list of numpy arrays (plus *audio_data* itself), one per model input,
        in model input order.

    Raises:
        NotImplementedError: If the model declares an input this helper does
            not know how to fill.
    """
    if model_name is None:
        model_name = MODEL_NAME

    ort_names = [entry.name for entry in ort_session.get_inputs()]
    print(ort_names)

    # English-only checkpoints (".en") use a different vocabulary size and
    # decoder start token than the multilingual ones.
    is_english_only = ".en" in model_name
    vocab_size = 51864 if is_english_only else 51865
    decoder_start_token_id = 50257 if is_english_only else 50258
    # Leading axis of the audio input is the batch axis (both callers here pass one item).
    batch_size = np.shape(audio_data)[0] if np.ndim(audio_data) > 0 else 1
    # Shape of the dummy attention mask. Presumably the mel-bin / frame counts of
    # the whisper-tiny feature extractor; TODO confirm for other model variants.
    n_mels, n_frames = 80, 3000

    # Fixed-value inputs are looked up by name, so the result follows the model's
    # own input order instead of assuming a positional layout.
    fixed_inputs = {
        "max_length": np.asarray([200], dtype=np.int32),
        "min_length": np.asarray([0], dtype=np.int32),
        "num_beams": np.asarray([2], dtype=np.int32),
        "num_return_sequences": np.asarray([1], dtype=np.int32),
        "length_penalty": np.asarray([1.0], dtype=np.float32),
        "repetition_penalty": np.asarray([1.0], dtype=np.float32),
        "logits_processor": np.array([1], dtype=np.int32),
    }

    inputs = []
    for name in ort_names:
        if name in ("audio_stream", "input_features"):
            inputs.append(audio_data)
        elif name in fixed_inputs:
            inputs.append(fixed_inputs[name])
        elif name == "vocab_mask":
            inputs.append(np.ones(vocab_size, dtype=np.int32))
        elif name == "prefix_vocab_mask":
            inputs.append(np.ones((batch_size, vocab_size), dtype=np.int32))
        elif name == "attention_mask":
            # For older ORT versions that have the dummy attention mask input for the beam search op
            inputs.append(np.zeros((batch_size, n_mels, n_frames), dtype=np.int32))
        elif name == "decoder_input_ids":
            # Every sequence in the batch starts from the decoder start token.
            inputs.append(np.full((batch_size, 1), decoder_start_token_id, dtype=np.int32))
        else:
            raise NotImplementedError(f"'{name}' input is not supported")
    return inputs
def main():
    """Run the Whisper end-to-end demo.

    Steps:
    1. Check the ONNX Runtime version and export the core beam-search model.
    2. Run pre-processing, the core model and post-processing one after another
       on the test clip.
    3. Merge the three graphs into one model, save it as ``FINAL_MODEL``, and run
       the merged model once on the raw test audio.
    """
    check_onnx_version()
    export_onnx_model()
    log_mel, pre_m, post_m = process_test_file()

    # Apply core ONNX model. Derive the exported file name from MODEL_NAME
    # (e.g. "whisper-tiny.en_beamsearch.onnx" for "openai/whisper-tiny.en") so
    # that changing MODEL_NAME does not leave a stale hard-coded path behind.
    core_model_file = MODEL_NAME.split("/")[-1] + "_beamsearch.onnx"
    fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, core_model_file), cpu_only=True)
    # NOTE: relies on OrtPyFunction's private _ensure_ort_session() to read the input names.
    fn_core_ort_session = fn_core._ensure_ort_session()
    model_inputs = get_model_inputs(fn_core_ort_session, log_mel)
    token_seq = fn_core(*model_inputs)
    print(token_seq.shape)

    # Apply post processing
    fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
    output_text = fn_post(token_seq)
    print(output_text)

    # Merge models and save final model
    print("Combine the data processing graphs into the ONNX model...")
    final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
    onnx.save(final_m, FINAL_MODEL)

    # Test the final model. It contains the pre-processing graph, so it is fed
    # the encoded audio bytes directly (as a batch of one).
    raw_audio = np.expand_dims(np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8), axis=0)
    e2e_model = OrtPyFunction.from_model(final_m, cpu_only=True)
    e2e_model_ort_session = e2e_model._ensure_ort_session()
    model_inputs = get_model_inputs(e2e_model_ort_session, raw_audio)
    text = e2e_model(*model_inputs)
    print(text)
# Entry point: only run the demo when this file is executed as a script.
if __name__ == '__main__':
    main()