Update whisper model test cases and e2e example (#496)

* Update whisper model test cases and e2e example

* fix unit test on windows

* more refinement

* utest fix
This commit is contained in:
Wenbing Li 2023-07-21 15:27:02 -07:00 committed by GitHub
Parent 06d5a8d781
Commit 62d8598b6b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 209 additions and 293 deletions

View file

@@ -12,9 +12,6 @@ This java and Android API and packaging principles were inspired by the https://
 1. install visual studio 2022 (with cmake, git, desktop C++)
 2. OpenJDK: https://docs.microsoft.com/en-us/java/openjdk/download
    (OpenJDK 11.0.15 LTS)
-3. Gradle: https://gradle.org/releases/
-   (v6.9.2)
 ### Build command
 ./build.sh **-DOCOS_BUILD_JAVA=ON**

View file

@@ -66,7 +66,7 @@ def _mel_filterbank(
         # intersect them with each other and zero
         fbank[i] = np.maximum(0, np.minimum(left, right))
 
-    energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
+    energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
     fbank *= energy_norm[:, np.newaxis]
     return fbank
@@ -109,9 +109,9 @@ class WhisperPrePipeline(torch.nn.Module):
         log_spec = torch.maximum(log_spec, spec_min)
         spec_shape = log_spec.shape
         padding_spec = torch.ones(spec_shape[0],
-                                  spec_shape[1], (
-                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH -
-                                  spec_shape[2]), dtype=torch.float)
+                                  spec_shape[1],
+                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
+                                  dtype=torch.float)
         padding_spec *= spec_min
         log_spec = torch.cat((log_spec, padding_spec), dim=2)
         log_spec = (log_spec + 4.0) / 4.0
@@ -138,7 +138,7 @@ def _to_onnx_stft(onnx_model):
                   value=numpy_helper.from_array(np.array([0,
                                                           _WhisperHParams.N_FFT // 2, 0,
                                                           _WhisperHParams.N_FFT // 2], dtype='int64'),
                                                 name='const_14')),
         make_node('Pad',
                   inputs=[stft_norm_node.input[0], 'const_14_output_0'],
                   outputs=['pad_1_output_0'], mode='reflect'),
@@ -225,9 +225,22 @@ class WhisperDataProcGraph:
         return pre_full
 
     def post_processing(self, **kwargs):
+        skip_special_tokens = kwargs.get('skip_special_tokens', True)
         g = SingleOpGraph.build_graph(
             "BpeDecoder",
             cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
-            skip_special_tokens=True,
-            cpu_only=True)
-        return make_onnx_model(g)
+            skip_special_tokens=skip_special_tokens)
+
+        bpenode = g.node[0]
+        bpenode.input[0] = "generated_ids"
+        nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
+                 bpenode]
+        del g.node[:]
+        g.node.extend(nodes)
+
+        inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
+        del g.input[:]
+        g.input.extend(inputs)
+        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
+        return make_onnx_model(g, opset_version=self.opset_version)
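With this change the decoder graph produced by post_processing no longer takes plain token ids; it expects the beam-search "sequences" tensor (int32, shaped [N, seq_len, ids]), casts it to int64 and feeds the BpeDecoder. A minimal usage sketch, mirroring the updated unit test in this PR (model name and helper functions are the ones used elsewhere in the change set):

    import numpy as np
    import transformers as _hfts
    from onnxruntime_extensions import OrtPyFunction, gen_processing_models

    processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    _, post_m = gen_processing_models(processor,
                                      pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
                                      post_kwargs={})
    fn_post = OrtPyFunction.from_model(post_m)
    # beam-search output: int32 tensor of shape [N, seq_len, ids]
    sequences = np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :]
    print(fn_post(sequences))  # one decoded string per sequence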

View file

@@ -60,7 +60,7 @@ def mel_filterbank(
         # intersect them with each other and zero
         fbank[i] = np.maximum(0, np.minimum(left, right))
 
-    energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
+    energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
     fbank *= energy_norm[:, np.newaxis]
     return fbank
@@ -126,3 +126,55 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
         elif attr.type == onnx.AttributeProto.GRAPHS:
             for subgraph in attr.graphs:
                 remove_unused_initializers(subgraph, top_level_initializers)
+
+
+def quick_merge(*models, connection_indices=None):
+    """
+    This function merges multiple ONNX models into a single model, without performing any ONNX format checks.
+
+    Parameters:
+        *models (onnx.ModelProto): Varargs parameter representing the ONNX models to be merged.
+        connection_indices (List[List[int]], optional): A nested list specifying which outputs in one model should
+            connect to which inputs in the next model, based on their indices. If not provided, it's assumed that
+            the sequence of outputs in one model exactly matches the sequence of inputs in the next model.
+
+    Returns:
+        merged_model (onnx.ModelProto): The merged ONNX model.
+
+    Raises:
+        ValueError: If there is any conflict in tensor names, either in initializers or in nodes, including subgraphs.
+                    If there is any conflict in opset versions for the same domain.
+    """
+    merged_graph = models[0].graph
+
+    # Dictionary to store unique opsets
+    opset_imports = {opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}
+
+    # Iterate over all other models and merge
+    for model_idx, model in enumerate(models[1:], start=1):
+        if connection_indices is None:
+            io_map = [(out.name, in_.name) for out, in_ in zip(models[model_idx - 1].graph.output, model.graph.input)]
+        else:
+            io_map = [(models[model_idx - 1].graph.output[out_idx].name, model.graph.input[in_idx].name)
+                      for out_idx, in_idx in connection_indices[model_idx - 1]]
+
+        merged_graph = onnx.compose.merge_graphs(merged_graph, model.graph, io_map)
+
+        for opset in model.opset_import:
+            if not opset.domain:
+                opset.domain = "ai.onnx"
+            if opset.domain in opset_imports and opset_imports[opset.domain].version != opset.version:
+                raise ValueError(f"Conflict in opset versions for domain '{opset.domain}': " +
+                                 f"model {model_idx} has version {opset.version}, while previous model has version " +
+                                 f"{opset_imports[opset.domain].version}.")
+            else:
+                opset_imports[opset.domain] = opset
+
+    default_opset = opset_imports.pop("ai.onnx", None)
+    merged_model = onnx.helper.make_model_gen_version(merged_graph,
+                                                      opset_imports=[default_opset],
+                                                      producer_name='ONNX Model Merger')
+    merged_model.opset_import.extend(opset_imports.values())
+    return merged_model
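The new Whisper end-to-end script in this PR chains its three ModelProtos with a single quick_merge call. A minimal sketch of the same pattern, under the stated assumption that each model's outputs line up one-to-one with the next model's inputs (the file names below are illustrative only; in the e2e example the ModelProtos come straight from gen_processing_models and the core export):

    import onnx
    from onnxruntime_extensions import util

    # hypothetical paths used only for illustration
    pre_m = onnx.load_model("whisper_pre.onnx")
    core_m = onnx.load_model("whisper-tiny.en_beamsearch.onnx")
    post_m = onnx.load_model("whisper_post.onnx")

    full_model = util.quick_merge(pre_m, core_m, post_m)
    onnx.save(full_model, "whisper_e2e_all.onnx")

Pass connection_indices when the output/input sequences of adjacent models do not already match positionally.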

View file

@@ -8,7 +8,9 @@
 FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
   static OrtOpLoader op_loader(
+      []() { return nullptr; }
 #ifdef ENABLE_DR_LIBS
+      ,
       CustomCpuStruct("AudioDecoder", AudioDecoder)
 #endif
   );

View file

@@ -3,7 +3,7 @@
 import wave
 import unittest
 import numpy as np
-from onnxruntime_extensions import PyOrtFunction, util, make_onnx_model
+from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
 import onnx
 from onnx import onnx_pb as onnx_proto
@@ -12,6 +12,7 @@ from onnx import onnx_pb as onnx_proto
 _is_torch_available = False
 try:
     import torch
     _is_torch_available = True
 except ImportError:
     pass
@@ -19,6 +20,7 @@ except ImportError:
 _is_librosa_avaliable = False
 try:
     import librosa
     _is_librosa_avaliable = True
 except ImportError:
     pass
@@ -57,7 +59,7 @@ class TestAudio(unittest.TestCase):
         audio = wavefile.readframes(samples)
         audio_as_np_int16 = np.frombuffer(audio, dtype=np.int16)
         audio_as_np_float32 = audio_as_np_int16.astype(np.float32)
-        max_int16 = 2**15
+        max_int16 = 2 ** 15
         cls.test_pcm = audio_as_np_float32 / max_int16
 
     @staticmethod
@@ -99,7 +101,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
-        ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
+        ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
@@ -109,7 +111,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
-        ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
+        ortx_stft = OrtPyFunction.from_customop("StftNorm", cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
@@ -125,7 +127,7 @@ class TestAudio(unittest.TestCase):
                              center=True,
                              return_complex=True).abs().pow(2).numpy()
         audio_pcm = np.expand_dims(self.test_pcm, axis=0)
-        ortx_stft = PyOrtFunction.from_customop("StftNorm")
+        ortx_stft = OrtPyFunction.from_customop("StftNorm")
         actual = ortx_stft(audio_pcm, 400, 160, np.hanning(wlen).astype(np.float32), 400)
         actual = actual[0]
         np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)

View file

@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+import sys
 import unittest
 
 import transformers as _hfts
@@ -9,6 +10,7 @@ from packaging import version
 from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
 
 
+@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
     def test_t5_tokenizer(self):
         tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
@@ -31,37 +33,77 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
         np.testing.assert_array_equal(ids[0][:-1], actual_ids)
 
-    @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
-    def test_whisper(self):
+    def test_whisper_overall(self):
         processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, post_m = gen_processing_models(processor,
                                               pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
                                               post_kwargs={})
         fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
-        t = np.linspace(0, 2*np.pi, 480000).astype(np.float32)
+        t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
         simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
         log_mel = fn_pre(simaudio)
         self.assertEqual(log_mel.shape, (1, 80, 3000))
 
         fn_post = OrtPyFunction.from_model(post_m)
-        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32))
+        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
         self.assertEqual(rel[0], "$%&")
 
-    @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
     def test_whisper_audio_decoder(self):
         processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
         fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
 
         test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
-        raw_audio = np.fromfile(test_flac_file, dtype=np.uint8)
-        log_mel = fn_pre(np.expand_dims(raw_audio, axis=0))
+        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
+        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
         self.assertEqual(log_mel.shape, (1, 80, 3000))
 
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_ort_stft_consistency(self):
+        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
+        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
+        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_stft_norm_consistency(self):
+        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
 
 
 if __name__ == '__main__':
     unittest.main()

View file

@@ -1,291 +1,99 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-
-import argparse
-import io
-import os
-import onnx
-import re
-import torch
-import numpy as np
-
-from onnx import numpy_helper
-from transformers import WhisperProcessor
-
-from onnxruntime_extensions import PyOrtFunction, util
-from onnxruntime_extensions.cvt import HFTokenizerConverter
-
-
-# the flags for pre-processing
-USE_ONNX_STFT = True
-USE_AUDIO_DECODER = True
-
-if not USE_AUDIO_DECODER:
-    try:
-        import librosa
-    except ImportError:
-        raise ImportError("Please pip3 install librosa without ort-extensions audio codec support.")
-
-
-# hard-coded audio hyperparameters
-# copied from https://github.com/openai/whisper/blob/main/whisper/audio.py#L12
-SAMPLE_RATE = 16000
-N_FFT = 400
-N_MELS = 80
-HOP_LENGTH = 160
-CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
-N_FRAMES = N_SAMPLES // HOP_LENGTH
-
-
-class CustomOpStftNorm(torch.autograd.Function):
-    @staticmethod
-    def symbolic(g, self, n_fft, hop_length, window):
-        t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
-        t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
-        t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
-        return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
-
-    @staticmethod
-    def forward(ctx, audio, n_fft, hop_length, window):
-        win_length = window.shape[0]
-        stft = torch.stft(audio, n_fft, hop_length, win_length, window,
-                          center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
-        return stft.abs() ** 2
-
-
-class WhisperPrePipeline(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.window = torch.hann_window(N_FFT)
-        self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
-
-    def forward(self, audio_pcm: torch.Tensor):
-        stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
-        magnitudes = stft_norm[:, :, :-1]
-        mel_spec = self.mel_filters @ magnitudes
-        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        spec_min = log_spec.max() - 8.0
-        log_spec = torch.maximum(log_spec, spec_min)
-        spec_shape = log_spec.shape
-        padding_spec = torch.ones(spec_shape[0],
-                                  spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
-        padding_spec *= spec_min
-        log_spec = torch.cat((log_spec, padding_spec), dim=2)
-        log_spec = (log_spec + 4.0) / 4.0
-        return log_spec
-
-
-def _to_onnx_stft(onnx_model):
-    """Convert custom-op STFT-Norm to ONNX STFT"""
-    node_idx = 0
-    new_stft_nodes = []
-    stft_norm_node = None
-    for node in onnx_model.graph.node:
-        if node.op_type == "StftNorm":
-            stft_norm_node = node
-            break
-        node_idx += 1
-
-    if stft_norm_node is None:
-        raise RuntimeError("Cannot find STFTNorm node in the graph")
-
-    make_node = onnx.helper.make_node
-    replaced_nodes = [
-        make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
-                  value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
-                                                name='const_14')),
-        make_node('Pad',
-                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
-                  outputs=['pad_1_output_0'], mode='reflect'),
-        make_node('STFT',
-                  inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
-                  outputs=['stft_output_0'], name='stft', domain='', onesided=1),
-        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
-                  perm=[0, 2, 1, 3]),
-        make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
-                  value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
-                  value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
-                  value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
-                  value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
-        make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
-                                   'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
-                  name='slice_1'),
-        make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
-        make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
-        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
-                  name='gather_4', axis=3),
-        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
-                  name='gather_5', axis=3),
-        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
-        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
-        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
-    ]
-    new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
-    new_stft_nodes.extend(replaced_nodes)
-    new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
-    del onnx_model.graph.node[:]
-    onnx_model.graph.node.extend(new_stft_nodes)
-    onnx.checker.check_model(onnx_model)
-    return onnx_model
-
-
-def _torch_export(*arg, **kwargs):
-    with io.BytesIO() as f:
-        torch.onnx.export(*arg, f, **kwargs)
-        return onnx.load_from_string(f.getvalue())
-
-
-def preprocessing(audio_data):
-    if USE_AUDIO_DECODER:
-        decoder = PyOrtFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
-        audio_pcm = torch.from_numpy(decoder(audio_data))
-    else:
-        audio_pcm = torch.from_numpy(audio_data)
-
-    prep_model_name = 'whisper_pre.onnx'
-    whisper_processing = WhisperPrePipeline()
-    model_args = (audio_pcm,)
-    pre_model = _torch_export(
-        whisper_processing,
-        model_args,
-        input_names=["audio_pcm"],
-        output_names=["log_mel"],
-        do_constant_folding=True,
-        export_params=True,
-        opset_version=17,
-        dynamic_axes={
-            "audio_pcm": {1: "sample_len"},
-        }
-    )
-    onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
-    if USE_ONNX_STFT:
-        pre_model = _to_onnx_stft(pre_model)
-        util.remove_unused_initializers(pre_model.graph)
-
-    pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
-    if not USE_AUDIO_DECODER:
-        return pre_f(audio_data)
-    else:
-        pre_full = onnx.compose.merge_models(
-            decoder.onnx_model,
-            pre_model,
-            io_map=[("floatPCM", "audio_pcm")])
-        pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
-        onnx.save_model(pre_f.onnx_model, os.path.join(root_dir, "whisper_codec_pre.onnx"))
-        result = pre_f(audio_data)
-        return result
-
-
-def merge_models(core: str, output_model: str, audio_data):
-    m_pre_path = os.path.join(root_dir, "whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
-    m_pre = onnx.load_model(m_pre_path)
-    m_core = onnx.load_model(core)
-    m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
-    m2 = onnx.load_model(os.path.join(root_dir, "whisper_post.onnx"))
-    m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
-    bpe_decoder_node = m_all.graph.node.pop(-1)
-    make_node = onnx.helper.make_node
-    bpe_decoder_node.input.pop(0)
-    bpe_decoder_node.input.extend(["generated_ids"])
-    m_all.graph.node.extend([
-        make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
-        bpe_decoder_node
-    ])
-    try:
-        onnx.save_model(m_all, output_model)
-    except ValueError:
-        onnx.save_model(m_all, output_model,
-                        save_as_external_data=True,
-                        all_tensors_to_one_file=True,
-                        location=f"{os.path.basename(output_model)}.data",
-                        convert_attribute=True)
-    print(f"The final merged model was saved as: {output_model}")
-
-    print("Verify the final model...")
-    m_final = PyOrtFunction.from_model(output_model, cpu_only=True)
-    output_text = m_final(audio_data,
-                          np.asarray([200], dtype=np.int32),
-                          np.asarray([0], dtype=np.int32),
-                          np.asarray([2], dtype=np.int32),
-                          np.asarray([1], dtype=np.int32),
-                          np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
-                          np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
-    print(output_text)
-
-
-def postprocessing(token_ids, hf_processor):
-    fn_decoder = PyOrtFunction.from_customop(
-        "BpeDecoder",
-        cvt=HFTokenizerConverter(hf_processor.tokenizer).bpe_decoder,
-        skip_special_tokens=True,
-        cpu_only=True)
-    onnx.save_model(fn_decoder.onnx_model, os.path.join(root_dir, "whisper_post.onnx"))
-    return fn_decoder(token_ids)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-a", "--audio", required=True, help="Path to audio file")
-    parser.add_argument("-m", "--model", required=True, help="Path to custom export of Whisper with beam search")
-    args = parser.parse_args()
-
-    print("Looking for the exported model...", end='')
-    onnx_model_name = os.path.basename(args.model)
-    if not re.search("whisper-.*_beamsearch\.onnx", onnx_model_name):
-        print("None")
-        print("Cannot find the whisper beamsearch ONNX models. "
-              "Please run this script from where Whisper ONNX model was exported. like */onnx_models/openai")
-        exit(-1)
-    else:
-        print(f"{onnx_model_name}")
-
-    model_name = "openai/" + onnx_model_name[:-len("_beamsearch.onnx")]
-    root_dir = os.path.dirname(args.model)
-    _processor = WhisperProcessor.from_pretrained(model_name)
-
-    # The model similar to Huggingface model like:
-    # model = WhisperForConditionalGeneration.from_pretrained(model_name)
-    # The onnx model can be generated by the following command:
-    # python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
-    # !!! only be valid after onnxruntime 1.15 or nightly build after 05/05/2023
-    model = PyOrtFunction.from_model(args.model, cpu_only=True)
-
-    test_file = util.get_test_data_file(args.audio)
-    if USE_AUDIO_DECODER:
-        with open(test_file, "rb") as _f:
-            audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
-    else:
-        audio_blob, _ = librosa.load(test_file)
-    audio_blob = np.expand_dims(audio_blob, axis=0)  # add a batch_size dimension
-
-    log_mel = preprocessing(audio_blob)
-    print(log_mel.shape)
-
-    input_features = log_mel
-    # similar to:
-    # generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
-    ort_outputs = model(input_features,
-                        np.asarray([200], dtype=np.int32),
-                        np.asarray([0], dtype=np.int32),
-                        np.asarray([2], dtype=np.int32),
-                        np.asarray([1], dtype=np.int32),
-                        np.asarray([1.0], dtype=np.float32),
-                        np.asarray([1.0], dtype=np.float32),
-                        np.zeros(input_features.shape).astype(np.int32))
-    generated_ids = ort_outputs[0]
-
-    text = postprocessing(generated_ids[0], _processor)
-    print(text)
-
-    print("build the final model...")
-    merge_models(args.model, args.model.replace("beamsearch", "all"), audio_blob)
+
+# Run the whisper end-to-end inference with ONNXRuntime-Extensions for pre/post processing.
+# THIS SCRIPT IS USED TO DEMO ONLY, WHICH IS NOT A PART OF THE PACKAGE.
+# TO GENERATE THE FULL-FUNCTION MODEL, PLEASE USE https://github.com/microsoft/Olive
+import os
+import onnx
+import subprocess
+import numpy as np
+import onnxruntime as ort
+
+from packaging import version
+from transformers import WhisperProcessor
+from onnxruntime_extensions import OrtPyFunction, util
+from onnxruntime_extensions.cvt import gen_processing_models
+
+# Constants
+MODEL_NAME = "openai/whisper-tiny.en"
+CACHE_DIR = 'temp_caches_onnx'
+OUTPUT_DIR = 'temp_model_onnx'
+FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"
+TEST_AUDIO_FILE = util.get_test_data_file('../test/data', "1272-141231-0002.mp3")
+
+
+def check_onnx_version():
+    if version.parse(ort.__version__) < version.parse("1.16.0"):
+        raise RuntimeError("ONNXRuntime version must >= 1.16.0")
+
+
+def export_onnx_model():
+    print("Exporting Whisper ONNX model from Huggingface model hub...")
+    command = ['python', '-m',
+               'onnxruntime.transformers.models.whisper.convert_to_onnx',
+               '-m', MODEL_NAME,
+               '--cache_dir', CACHE_DIR,
+               '--output', OUTPUT_DIR,
+               '--precision', 'fp32']
+    process = subprocess.run(command)
+    if process.returncode != 0:
+        raise RuntimeError("Failed to export the core ONNX models.")
+
+
+def process_test_file():
+    if not os.path.exists(TEST_AUDIO_FILE):
+        raise FileNotFoundError(f"Test audio path {TEST_AUDIO_FILE} does not exist.")
+
+    raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
+    _processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+    pre_m, post_m = gen_processing_models(_processor,
+                                          pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True},
+                                          post_kwargs={},
+                                          opset=17)
+    fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+    return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
+
+
+def main():
+    check_onnx_version()
+    export_onnx_model()
+    log_mel, pre_m, post_m = process_test_file()
+
+    # Apply core ONNX model
+    fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
+    token_seq = fn_core(log_mel,
+                        np.asarray([200], dtype=np.int32),
+                        np.asarray([0], dtype=np.int32),
+                        np.asarray([2], dtype=np.int32),
+                        np.asarray([1], dtype=np.int32),
+                        np.asarray([1.0], dtype=np.float32),
+                        np.asarray([1.0], dtype=np.float32))
+    print(token_seq.shape)
+
+    # Apply post processing
+    fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
+    output_text = fn_post(token_seq)
+    print(output_text)
+
+    # Merge models and save final model
+    print("Combine the data processing graphs into the ONNX model...")
+    final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
+    onnx.save(final_m, FINAL_MODEL)
+
+    # Test the final model
+    raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
+    text = OrtPyFunction.from_model(final_m, cpu_only=True)(
+        np.expand_dims(raw_audio, axis=0),
+        np.asarray([200], dtype=np.int32),
+        np.asarray([0], dtype=np.int32),
+        np.asarray([2], dtype=np.int32),
+        np.asarray([1], dtype=np.int32),
+        np.asarray([1.0], dtype=np.float32),
+        np.asarray([1.0], dtype=np.float32))
+    print(text)
+
+
+if __name__ == "__main__":
+    main()
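The merged whisper_onnx_tiny_en_fp32_e2e.onnx produced above embeds the AudioDecoder, STFT, and BpeDecoder custom operators, so any session that loads it must register the onnxruntime-extensions custom-op library. A minimal sketch, assuming the merged model and the mp3 file are in the working directory (input names are read from the session rather than hard-coded, since this PR does not list them):

    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())  # enables the AudioDecoder/StftNorm/BpeDecoder ops
    sess = ort.InferenceSession("whisper_onnx_tiny_en_fp32_e2e.onnx", so, providers=["CPUExecutionProvider"])

    audio_blob = np.expand_dims(np.fromfile("1272-141231-0002.mp3", dtype=np.uint8), axis=0)
    values = [audio_blob,
              np.asarray([200], dtype=np.int32), np.asarray([0], dtype=np.int32),
              np.asarray([2], dtype=np.int32), np.asarray([1], dtype=np.int32),
              np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32)]
    feeds = {i.name: v for i, v in zip(sess.get_inputs(), values)}
    print(sess.run(None, feeds)[0])

The feed values mirror the beam-search parameters used by the script's own final test call (max_length, min_length, num_beams, num_return_sequences, length_penalty, repetition_penalty).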