Update whisper model test cases and e2e example (#496)
* Update whisper model test cases and e2e example * fix unit test on windows * more refinement * utest fix
This commit is contained in:
Родитель
06d5a8d781
Коммит
62d8598b6b
|
@ -12,9 +12,6 @@ This java and Android API and packaging principles were inspired by the https://
|
|||
1. install visual studio 2022 (with cmake, git, desktop C++)
|
||||
2. OpenJDK: https://docs.microsoft.com/en-us/java/openjdk/download
|
||||
(OpenJDK 11.0.15 LTS)
|
||||
3. Gradle: https://gradle.org/releases/
|
||||
(v6.9.2)
|
||||
|
||||
|
||||
### Build command
|
||||
./build.sh **-DOCOS_BUILD_JAVA=ON**
|
||||
|
|
|
@ -66,7 +66,7 @@ def _mel_filterbank(
|
|||
# intersect them with each other and zero
|
||||
fbank[i] = np.maximum(0, np.minimum(left, right))
|
||||
|
||||
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
|
||||
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
|
||||
fbank *= energy_norm[:, np.newaxis]
|
||||
return fbank
|
||||
|
||||
|
@ -109,9 +109,9 @@ class WhisperPrePipeline(torch.nn.Module):
|
|||
log_spec = torch.maximum(log_spec, spec_min)
|
||||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1], (
|
||||
_WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH -
|
||||
spec_shape[2]), dtype=torch.float)
|
||||
spec_shape[1],
|
||||
_WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
|
||||
dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
|
@ -138,7 +138,7 @@ def _to_onnx_stft(onnx_model):
|
|||
value=numpy_helper.from_array(np.array([0,
|
||||
_WhisperHParams.N_FFT // 2, 0,
|
||||
_WhisperHParams.N_FFT // 2], dtype='int64'),
|
||||
name='const_14')),
|
||||
name='const_14')),
|
||||
make_node('Pad',
|
||||
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
|
||||
outputs=['pad_1_output_0'], mode='reflect'),
|
||||
|
@ -225,9 +225,22 @@ class WhisperDataProcGraph:
|
|||
return pre_full
|
||||
|
||||
def post_processing(self, **kwargs):
|
||||
skip_special_tokens = kwargs.get('skip_special_tokens', True)
|
||||
g = SingleOpGraph.build_graph(
|
||||
"BpeDecoder",
|
||||
cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
|
||||
skip_special_tokens=True,
|
||||
cpu_only=True)
|
||||
return make_onnx_model(g)
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
bpenode = g.node[0]
|
||||
bpenode.input[0] = "generated_ids"
|
||||
nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||
bpenode]
|
||||
del g.node[:]
|
||||
g.node.extend(nodes)
|
||||
|
||||
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
|
||||
del g.input[:]
|
||||
g.input.extend(inputs)
|
||||
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
|
||||
|
||||
return make_onnx_model(g, opset_version=self.opset_version)
|
||||
|
|
|
@ -60,7 +60,7 @@ def mel_filterbank(
|
|||
# intersect them with each other and zero
|
||||
fbank[i] = np.maximum(0, np.minimum(left, right))
|
||||
|
||||
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
|
||||
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
|
||||
fbank *= energy_norm[:, np.newaxis]
|
||||
return fbank
|
||||
|
||||
|
@ -126,3 +126,55 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
|
|||
elif attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for subgraph in attr.graphs:
|
||||
remove_unused_initializers(subgraph, top_level_initializers)
|
||||
|
||||
|
||||
def quick_merge(*models, connection_indices=None):
|
||||
"""
|
||||
This function merges multiple ONNX models into a single model, without performing any ONNX format checks.
|
||||
|
||||
Parameters:
|
||||
*models (onnx.ModelProto): Varargs parameter representing the ONNX models to be merged.
|
||||
connection_indices (List[List[int]], optional): A nested list specifying which outputs in one model should connect
|
||||
to which inputs in the next model, based on their indices.
|
||||
If not provided, it's assumed that the sequence of outputs in
|
||||
one model exactly matches the sequence of inputs in the next model.
|
||||
|
||||
Returns:
|
||||
merged_model (onnx.ModelProto): The merged ONNX model.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is any conflict in tensor names, either in initializers or in nodes, including subgraphs.
|
||||
If there is any conflict in opset versions for the same domain.
|
||||
"""
|
||||
|
||||
merged_graph = models[0].graph
|
||||
|
||||
# Dictionary to store unique opsets
|
||||
opset_imports = {opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}
|
||||
|
||||
# Iterate over all other models and merge
|
||||
for model_idx, model in enumerate(models[1:], start=1):
|
||||
if connection_indices is None:
|
||||
io_map = [(out.name, in_.name) for out, in_ in zip(models[model_idx - 1].graph.output, model.graph.input)]
|
||||
else:
|
||||
io_map = [(models[model_idx - 1].graph.output[out_idx].name, model.graph.input[in_idx].name)
|
||||
for out_idx, in_idx in connection_indices[model_idx - 1]]
|
||||
|
||||
merged_graph = onnx.compose.merge_graphs(merged_graph, model.graph, io_map)
|
||||
|
||||
for opset in model.opset_import:
|
||||
if not opset.domain:
|
||||
opset.domain = "ai.onnx"
|
||||
if opset.domain in opset_imports and opset_imports[opset.domain].version != opset.version:
|
||||
raise ValueError(f"Conflict in opset versions for domain '{opset.domain}': " +
|
||||
f"model {model_idx} has version {opset.version}, while previous model has version " +
|
||||
f"{opset_imports[opset.domain].version}.")
|
||||
else:
|
||||
opset_imports[opset.domain] = opset
|
||||
|
||||
default_opset = opset_imports.pop("ai.onnx", None)
|
||||
merged_model = onnx.helper.make_model_gen_version(merged_graph,
|
||||
opset_imports=[default_opset],
|
||||
producer_name='ONNX Model Merger')
|
||||
merged_model.opset_import.extend(opset_imports.values())
|
||||
return merged_model
|
||||
|
|
|
@ -8,7 +8,9 @@
|
|||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
|
||||
static OrtOpLoader op_loader(
|
||||
[]() { return nullptr; }
|
||||
#ifdef ENABLE_DR_LIBS
|
||||
,
|
||||
CustomCpuStruct("AudioDecoder", AudioDecoder)
|
||||
#endif
|
||||
);
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import wave
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import PyOrtFunction, util, make_onnx_model
|
||||
from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model
|
||||
|
||||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
@ -12,6 +12,7 @@ from onnx import onnx_pb as onnx_proto
|
|||
_is_torch_available = False
|
||||
try:
|
||||
import torch
|
||||
|
||||
_is_torch_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -19,6 +20,7 @@ except ImportError:
|
|||
_is_librosa_avaliable = False
|
||||
try:
|
||||
import librosa
|
||||
|
||||
_is_librosa_avaliable = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -57,7 +59,7 @@ class TestAudio(unittest.TestCase):
|
|||
audio = wavefile.readframes(samples)
|
||||
audio_as_np_int16 = np.frombuffer(audio, dtype=np.int16)
|
||||
audio_as_np_float32 = audio_as_np_int16.astype(np.float32)
|
||||
max_int16 = 2**15
|
||||
max_int16 = 2 ** 15
|
||||
cls.test_pcm = audio_as_np_float32 / max_int16
|
||||
|
||||
@staticmethod
|
||||
|
@ -99,7 +101,7 @@ class TestAudio(unittest.TestCase):
|
|||
audio_pcm = self.test_pcm
|
||||
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
|
||||
|
||||
ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
|
||||
ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
|
||||
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
|
||||
actual = actual[0]
|
||||
actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
|
||||
|
@ -109,7 +111,7 @@ class TestAudio(unittest.TestCase):
|
|||
audio_pcm = self.test_pcm
|
||||
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
|
||||
|
||||
ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
|
||||
ortx_stft = OrtPyFunction.from_customop("StftNorm", cpu_only=True)
|
||||
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
|
||||
actual = actual[0]
|
||||
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
|
||||
|
@ -125,7 +127,7 @@ class TestAudio(unittest.TestCase):
|
|||
center=True,
|
||||
return_complex=True).abs().pow(2).numpy()
|
||||
audio_pcm = np.expand_dims(self.test_pcm, axis=0)
|
||||
ortx_stft = PyOrtFunction.from_customop("StftNorm")
|
||||
ortx_stft = OrtPyFunction.from_customop("StftNorm")
|
||||
actual = ortx_stft(audio_pcm, 400, 160, np.hanning(wlen).astype(np.float32), 400)
|
||||
actual = actual[0]
|
||||
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import sys
|
||||
import unittest
|
||||
import transformers as _hfts
|
||||
|
||||
|
@ -9,6 +10,7 @@ from packaging import version
|
|||
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
|
||||
|
||||
|
||||
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
|
||||
class TestAutoTokenizer(unittest.TestCase):
|
||||
def test_t5_tokenizer(self):
|
||||
tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
|
||||
|
@ -31,37 +33,77 @@ class TestAutoTokenizer(unittest.TestCase):
|
|||
actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
|
||||
np.testing.assert_array_equal(ids[0][:-1], actual_ids)
|
||||
|
||||
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
|
||||
def test_whisper(self):
|
||||
def test_whisper_overall(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, post_m = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
|
||||
post_kwargs={})
|
||||
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
|
||||
post_kwargs={})
|
||||
|
||||
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
|
||||
t = np.linspace(0, 2*np.pi, 480000).astype(np.float32)
|
||||
t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32)
|
||||
simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
|
||||
log_mel = fn_pre(simaudio)
|
||||
|
||||
self.assertEqual(log_mel.shape, (1, 80, 3000))
|
||||
|
||||
fn_post = OrtPyFunction.from_model(post_m)
|
||||
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32))
|
||||
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
|
||||
self.assertEqual(rel[0], "$%&")
|
||||
|
||||
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
|
||||
def test_whisper_audio_decoder(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, _ = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
|
||||
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
|
||||
|
||||
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
|
||||
test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
|
||||
raw_audio = np.fromfile(test_flac_file, dtype=np.uint8)
|
||||
log_mel = fn_pre(np.expand_dims(raw_audio, axis=0))
|
||||
audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
|
||||
log_mel = fn_pre(np.expand_dims(audio_data, axis=0))
|
||||
|
||||
self.assertEqual(log_mel.shape, (1, 80, 3000))
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
|
||||
def test_ort_stft_consistency(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, _ = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
|
||||
|
||||
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
|
||||
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
|
||||
raw_audio = OrtPyFunction.from_customop(
|
||||
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
|
||||
|
||||
input_features = processor([raw_audio[0]], sampling_rate=16000)
|
||||
expected = input_features['input_features'][0]
|
||||
|
||||
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
|
||||
actual = log_mel[0]
|
||||
|
||||
num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
|
||||
# ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
|
||||
self.assertTrue(num_mismatched / np.size(expected) < 0.02)
|
||||
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
|
||||
def test_stft_norm_consistency(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, _ = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
|
||||
|
||||
test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
|
||||
test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
|
||||
raw_audio = OrtPyFunction.from_customop(
|
||||
"AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
|
||||
|
||||
input_features = processor([raw_audio[0]], sampling_rate=16000)
|
||||
expected = input_features['input_features'][0]
|
||||
|
||||
log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
|
||||
actual = log_mel[0]
|
||||
|
||||
np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||
self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,291 +1,99 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import argparse
|
||||
import io
|
||||
|
||||
# Run the whisper end-to-end inference with ONNXRuntime-Extensions for pre/post processing.
|
||||
# THIS SCRIPT IS USED TO DEMO ONLY, WHICH IS NOT A PART OF THE PACKAGE.
|
||||
# TO GENERATE THE FULL-FUNCTION MODEL, PLEASE USE https://github.com/microsoft/Olive
|
||||
import os
|
||||
import onnx
|
||||
import re
|
||||
import torch
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from onnx import numpy_helper
|
||||
from packaging import version
|
||||
from transformers import WhisperProcessor
|
||||
from onnxruntime_extensions import OrtPyFunction, util
|
||||
from onnxruntime_extensions.cvt import gen_processing_models
|
||||
|
||||
from onnxruntime_extensions import PyOrtFunction, util
|
||||
from onnxruntime_extensions.cvt import HFTokenizerConverter
|
||||
# Constants
|
||||
MODEL_NAME = "openai/whisper-tiny.en"
|
||||
CACHE_DIR = 'temp_caches_onnx'
|
||||
OUTPUT_DIR = 'temp_model_onnx'
|
||||
FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"
|
||||
TEST_AUDIO_FILE = util.get_test_data_file('../test/data', "1272-141231-0002.mp3")
|
||||
|
||||
|
||||
# the flags for pre-processing
|
||||
USE_ONNX_STFT = True
|
||||
USE_AUDIO_DECODER = True
|
||||
def check_onnx_version():
|
||||
if version.parse(ort.__version__) < version.parse("1.16.0"):
|
||||
raise RuntimeError("ONNXRuntime version must >= 1.16.0")
|
||||
|
||||
|
||||
if not USE_AUDIO_DECODER:
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
raise ImportError("Please pip3 install librosa without ort-extensions audio codec support.")
|
||||
def export_onnx_model():
|
||||
print("Exporting Whisper ONNX model from Huggingface model hub...")
|
||||
command = ['python', '-m',
|
||||
'onnxruntime.transformers.models.whisper.convert_to_onnx',
|
||||
'-m', MODEL_NAME,
|
||||
'--cache_dir', CACHE_DIR,
|
||||
'--output', OUTPUT_DIR,
|
||||
'--precision', 'fp32']
|
||||
process = subprocess.run(command)
|
||||
if process.returncode != 0:
|
||||
raise RuntimeError("Failed to export the core ONNX models.")
|
||||
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
# copied from https://github.com/openai/whisper/blob/main/whisper/audio.py#L12
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = N_SAMPLES // HOP_LENGTH
|
||||
def process_test_file():
|
||||
if not os.path.exists(TEST_AUDIO_FILE):
|
||||
raise FileNotFoundError(f"Test audio path {TEST_AUDIO_FILE} does not exist.")
|
||||
|
||||
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
|
||||
_processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
|
||||
pre_m, post_m = gen_processing_models(_processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True},
|
||||
post_kwargs={},
|
||||
opset=17)
|
||||
|
||||
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
|
||||
return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
|
||||
|
||||
|
||||
class CustomOpStftNorm(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, self, n_fft, hop_length, window):
|
||||
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
|
||||
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
|
||||
def main():
|
||||
check_onnx_version()
|
||||
export_onnx_model()
|
||||
log_mel, pre_m, post_m = process_test_file()
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, audio, n_fft, hop_length, window):
|
||||
win_length = window.shape[0]
|
||||
stft = torch.stft(audio, n_fft, hop_length, win_length, window,
|
||||
center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
||||
return stft.abs() ** 2
|
||||
|
||||
|
||||
class WhisperPrePipeline(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.window = torch.hann_window(N_FFT)
|
||||
self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
|
||||
|
||||
def forward(self, audio_pcm: torch.Tensor):
|
||||
stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
|
||||
magnitudes = stft_norm[:, :, :-1]
|
||||
mel_spec = self.mel_filters @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
spec_min = log_spec.max() - 8.0
|
||||
log_spec = torch.maximum(log_spec, spec_min)
|
||||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
|
||||
def _to_onnx_stft(onnx_model):
|
||||
"""Convert custom-op STFT-Norm to ONNX STFT"""
|
||||
node_idx = 0
|
||||
new_stft_nodes = []
|
||||
stft_norm_node = None
|
||||
for node in onnx_model.graph.node:
|
||||
if node.op_type == "StftNorm":
|
||||
stft_norm_node = node
|
||||
break
|
||||
node_idx += 1
|
||||
|
||||
if stft_norm_node is None:
|
||||
raise RuntimeError("Cannot find STFTNorm node in the graph")
|
||||
|
||||
make_node = onnx.helper.make_node
|
||||
replaced_nodes = [
|
||||
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
|
||||
value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
|
||||
name='const_14')),
|
||||
make_node('Pad',
|
||||
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
|
||||
outputs=['pad_1_output_0'], mode='reflect'),
|
||||
make_node('STFT',
|
||||
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
|
||||
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
|
||||
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
|
||||
perm=[0, 2, 1, 3]),
|
||||
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
|
||||
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
|
||||
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
|
||||
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
|
||||
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
|
||||
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
|
||||
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
|
||||
name='slice_1'),
|
||||
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
|
||||
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
|
||||
name='gather_4', axis=3),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
|
||||
name='gather_5', axis=3),
|
||||
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
|
||||
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
|
||||
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
|
||||
]
|
||||
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
|
||||
new_stft_nodes.extend(replaced_nodes)
|
||||
new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
|
||||
del onnx_model.graph.node[:]
|
||||
onnx_model.graph.node.extend(new_stft_nodes)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def _torch_export(*arg, **kwargs):
|
||||
with io.BytesIO() as f:
|
||||
torch.onnx.export(*arg, f, **kwargs)
|
||||
return onnx.load_from_string(f.getvalue())
|
||||
|
||||
|
||||
def preprocessing(audio_data):
|
||||
if USE_AUDIO_DECODER:
|
||||
decoder = PyOrtFunction.from_customop(
|
||||
"AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
|
||||
audio_pcm = torch.from_numpy(decoder(audio_data))
|
||||
else:
|
||||
audio_pcm = torch.from_numpy(audio_data)
|
||||
|
||||
prep_model_name = 'whisper_pre.onnx'
|
||||
whisper_processing = WhisperPrePipeline()
|
||||
|
||||
model_args = (audio_pcm,)
|
||||
pre_model = _torch_export(
|
||||
whisper_processing,
|
||||
model_args,
|
||||
input_names=["audio_pcm"],
|
||||
output_names=["log_mel"],
|
||||
do_constant_folding=True,
|
||||
export_params=True,
|
||||
opset_version=17,
|
||||
dynamic_axes={
|
||||
"audio_pcm": {1: "sample_len"},
|
||||
}
|
||||
)
|
||||
onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
|
||||
if USE_ONNX_STFT:
|
||||
pre_model = _to_onnx_stft(pre_model)
|
||||
util.remove_unused_initializers(pre_model.graph)
|
||||
|
||||
pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
|
||||
if not USE_AUDIO_DECODER:
|
||||
return pre_f(audio_data)
|
||||
else:
|
||||
pre_full = onnx.compose.merge_models(
|
||||
decoder.onnx_model,
|
||||
pre_model,
|
||||
io_map=[("floatPCM", "audio_pcm")])
|
||||
pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
|
||||
|
||||
onnx.save_model(pre_f.onnx_model, os.path.join(root_dir, "whisper_codec_pre.onnx"))
|
||||
result = pre_f(audio_data)
|
||||
return result
|
||||
|
||||
|
||||
def merge_models(core: str, output_model: str, audio_data):
|
||||
m_pre_path = os.path.join(root_dir, "whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
|
||||
m_pre = onnx.load_model(m_pre_path)
|
||||
m_core = onnx.load_model(core)
|
||||
m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
|
||||
m2 = onnx.load_model(os.path.join(root_dir, "whisper_post.onnx"))
|
||||
|
||||
m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
|
||||
bpe_decoder_node = m_all.graph.node.pop(-1)
|
||||
make_node = onnx.helper.make_node
|
||||
bpe_decoder_node.input.pop(0)
|
||||
bpe_decoder_node.input.extend(["generated_ids"])
|
||||
m_all.graph.node.extend([
|
||||
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||
bpe_decoder_node
|
||||
])
|
||||
try:
|
||||
onnx.save_model(m_all, output_model)
|
||||
except ValueError:
|
||||
onnx.save_model(m_all, output_model,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=f"{os.path.basename(output_model)}.data",
|
||||
convert_attribute=True)
|
||||
print(f"The final merged model was saved as: {output_model}")
|
||||
|
||||
print("Verify the final model...")
|
||||
m_final = PyOrtFunction.from_model(output_model, cpu_only=True)
|
||||
output_text = m_final(audio_data,
|
||||
np.asarray([200], dtype=np.int32),
|
||||
np.asarray([0], dtype=np.int32),
|
||||
np.asarray([2], dtype=np.int32),
|
||||
np.asarray([1], dtype=np.int32),
|
||||
np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
|
||||
np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
|
||||
print(output_text)
|
||||
|
||||
|
||||
def postprocessing(token_ids, hf_processor):
|
||||
fn_decoder = PyOrtFunction.from_customop(
|
||||
"BpeDecoder",
|
||||
cvt=HFTokenizerConverter(hf_processor.tokenizer).bpe_decoder,
|
||||
skip_special_tokens=True,
|
||||
cpu_only=True)
|
||||
|
||||
onnx.save_model(fn_decoder.onnx_model, os.path.join(root_dir, "whisper_post.onnx"))
|
||||
return fn_decoder(token_ids)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-a", "--audio", required=True, help="Path to audio file")
|
||||
parser.add_argument("-m", "--model", required=True, help="Path to custom export of Whisper with beam search")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Looking for the exported model...", end='')
|
||||
onnx_model_name = os.path.basename(args.model)
|
||||
if not re.search("whisper-.*_beamsearch\.onnx", onnx_model_name):
|
||||
print("None")
|
||||
print("Cannot find the whisper beamsearch ONNX models. "
|
||||
"Please run this script from where Whisper ONNX model was exported. like */onnx_models/openai")
|
||||
exit(-1)
|
||||
else:
|
||||
print(f"{onnx_model_name}")
|
||||
|
||||
model_name = "openai/" + onnx_model_name[:-len("_beamsearch.onnx")]
|
||||
root_dir = os.path.dirname(args.model)
|
||||
|
||||
_processor = WhisperProcessor.from_pretrained(model_name)
|
||||
# The model similar to Huggingface model like:
|
||||
# model = WhisperForConditionalGeneration.from_pretrained(model_name)
|
||||
|
||||
# The onnx model can be generated by the following command:
|
||||
# python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
|
||||
# !!! only be valid after onnxruntime 1.15 or nightly build after 05/05/2023
|
||||
model = PyOrtFunction.from_model(args.model, cpu_only=True)
|
||||
|
||||
test_file = util.get_test_data_file(args.audio)
|
||||
if USE_AUDIO_DECODER:
|
||||
with open(test_file, "rb") as _f:
|
||||
audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
|
||||
else:
|
||||
audio_blob, _ = librosa.load(test_file)
|
||||
audio_blob = np.expand_dims(audio_blob, axis=0) # add a batch_size dimension
|
||||
|
||||
log_mel = preprocessing(audio_blob)
|
||||
print(log_mel.shape)
|
||||
|
||||
input_features = log_mel
|
||||
# similar to:
|
||||
# generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
|
||||
ort_outputs = model(input_features,
|
||||
# Apply core ONNX model
|
||||
fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
|
||||
token_seq = fn_core(log_mel,
|
||||
np.asarray([200], dtype=np.int32),
|
||||
np.asarray([0], dtype=np.int32),
|
||||
np.asarray([2], dtype=np.int32),
|
||||
np.asarray([1], dtype=np.int32),
|
||||
np.asarray([1.0], dtype=np.float32),
|
||||
np.asarray([1.0], dtype=np.float32),
|
||||
np.zeros(input_features.shape).astype(np.int32))
|
||||
generated_ids = ort_outputs[0]
|
||||
np.asarray([1.0], dtype=np.float32))
|
||||
print(token_seq.shape)
|
||||
|
||||
text = postprocessing(generated_ids[0], _processor)
|
||||
# Apply post processing
|
||||
fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
|
||||
output_text = fn_post(token_seq)
|
||||
print(output_text)
|
||||
|
||||
# Merge models and save final model
|
||||
print("Combine the data processing graphs into the ONNX model...")
|
||||
final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
|
||||
onnx.save(final_m, FINAL_MODEL)
|
||||
|
||||
# Test the final model
|
||||
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
|
||||
text = OrtPyFunction.from_model(final_m, cpu_only=True)(
|
||||
np.expand_dims(raw_audio, axis=0),
|
||||
np.asarray([200], dtype=np.int32),
|
||||
np.asarray([0], dtype=np.int32),
|
||||
np.asarray([2], dtype=np.int32),
|
||||
np.asarray([1], dtype=np.int32),
|
||||
np.asarray([1.0], dtype=np.float32),
|
||||
np.asarray([1.0], dtype=np.float32))
|
||||
print(text)
|
||||
|
||||
print("build the final model...")
|
||||
merge_models(args.model, args.model.replace("beamsearch", "all"), audio_blob)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче