Update whisper model test cases and e2e example (#496)
* Update whisper model test cases and e2e example
* fix unit test on windows
* more refinement
* utest fix
This commit is contained in:
Parent
06d5a8d781
Commit
62d8598b6b
@@ -12,9 +12,6 @@ This java and Android API and packaging principles were inspired by the https://
 1. install visual studio 2022 (with cmake, git, desktop C++)
 2. OpenJDK: https://docs.microsoft.com/en-us/java/openjdk/download
    (OpenJDK 11.0.15 LTS)
-3. Gradle: https://gradle.org/releases/
-   (v6.9.2)
-

 ### Build command
 ./build.sh **-DOCOS_BUILD_JAVA=ON**
@@ -109,9 +109,9 @@ class WhisperPrePipeline(torch.nn.Module):
         log_spec = torch.maximum(log_spec, spec_min)
         spec_shape = log_spec.shape
         padding_spec = torch.ones(spec_shape[0],
-                                  spec_shape[1], (
-                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH -
-                                  spec_shape[2]), dtype=torch.float)
+                                  spec_shape[1],
+                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
+                                  dtype=torch.float)
         padding_spec *= spec_min
         log_spec = torch.cat((log_spec, padding_spec), dim=2)
         log_spec = (log_spec + 4.0) / 4.0
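The padding width above targets Whisper's fixed 30-second feature window. A quick sanity check of that arithmetic, using the hard-coded hyperparameters that appear elsewhere in this commit (the frame count for the short clip is a hypothetical example):

```python
# Whisper audio hyperparameters as defined in this repository's Whisper code.
SAMPLE_RATE = 16000
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE      # 480000 samples in a 30-second chunk

target_frames = N_SAMPLES // HOP_LENGTH     # 3000 mel frames expected by the encoder
spec_frames = 1876                          # hypothetical frame count of a shorter clip
pad_frames = target_frames - spec_frames    # width of the padding concatenated on dim=2
print(target_frames, pad_frames)            # 3000 1124
```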
@@ -225,9 +225,22 @@ class WhisperDataProcGraph:
         return pre_full

     def post_processing(self, **kwargs):
+        skip_special_tokens = kwargs.get('skip_special_tokens', True)
         g = SingleOpGraph.build_graph(
             "BpeDecoder",
             cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
-            skip_special_tokens=True,
-            cpu_only=True)
-        return make_onnx_model(g)
+            skip_special_tokens=skip_special_tokens)
+        bpenode = g.node[0]
+        bpenode.input[0] = "generated_ids"
+        nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
+                 bpenode]
+        del g.node[:]
+        g.node.extend(nodes)
+
+        inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
+        del g.input[:]
+        g.input.extend(inputs)
+        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text']))
+
+        return make_onnx_model(g, opset_version=self.opset_version)
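A minimal sketch of driving the rewritten post-processing graph, mirroring the updated unit test later in this commit; `post_m` is assumed to be the model produced by the code above:

```python
import numpy as np
from onnxruntime_extensions import OrtPyFunction

# post_m: the post-processing ONNX model produced above (assumed available here).
fn_post = OrtPyFunction.from_model(post_m)

# The graph now takes the beam-search output "sequences" as an int32 tensor of
# rank 3 ['N', 'seq_len', 'ids'], casts it to int64 and feeds the BpeDecoder.
ids = np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :]
print(fn_post(ids)[0])   # decoded text; the unit test expects "$%&" for these ids
```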
@@ -126,3 +126,55 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
             elif attr.type == onnx.AttributeProto.GRAPHS:
                 for subgraph in attr.graphs:
                     remove_unused_initializers(subgraph, top_level_initializers)
+
+
+def quick_merge(*models, connection_indices=None):
+    """
+    This function merges multiple ONNX models into a single model, without performing any ONNX format checks.
+
+    Parameters:
+        *models (onnx.ModelProto): Varargs parameter representing the ONNX models to be merged.
+        connection_indices (List[List[int]], optional): A nested list specifying which outputs in one model should connect
+                                                         to which inputs in the next model, based on their indices.
+                                                         If not provided, it's assumed that the sequence of outputs in
+                                                         one model exactly matches the sequence of inputs in the next model.
+
+    Returns:
+        merged_model (onnx.ModelProto): The merged ONNX model.
+
+    Raises:
+        ValueError: If there is any conflict in tensor names, either in initializers or in nodes, including subgraphs.
+                    If there is any conflict in opset versions for the same domain.
+    """
+    merged_graph = models[0].graph
+
+    # Dictionary to store unique opsets
+    opset_imports = {opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}
+
+    # Iterate over all other models and merge
+    for model_idx, model in enumerate(models[1:], start=1):
+        if connection_indices is None:
+            io_map = [(out.name, in_.name) for out, in_ in zip(models[model_idx - 1].graph.output, model.graph.input)]
+        else:
+            io_map = [(models[model_idx - 1].graph.output[out_idx].name, model.graph.input[in_idx].name)
+                      for out_idx, in_idx in connection_indices[model_idx - 1]]
+
+        merged_graph = onnx.compose.merge_graphs(merged_graph, model.graph, io_map)
+
+        for opset in model.opset_import:
+            if not opset.domain:
+                opset.domain = "ai.onnx"
+            if opset.domain in opset_imports and opset_imports[opset.domain].version != opset.version:
+                raise ValueError(f"Conflict in opset versions for domain '{opset.domain}': " +
+                                 f"model {model_idx} has version {opset.version}, while previous model has version " +
+                                 f"{opset_imports[opset.domain].version}.")
+            else:
+                opset_imports[opset.domain] = opset
+
+    default_opset = opset_imports.pop("ai.onnx", None)
+    merged_model = onnx.helper.make_model_gen_version(merged_graph,
+                                                      opset_imports=[default_opset],
+                                                      producer_name='ONNX Model Merger')
+    merged_model.opset_import.extend(opset_imports.values())
+    return merged_model
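A minimal usage sketch of `quick_merge`, following how the updated end-to-end example later in this commit chains the processing graphs around the core beam-search model (the model variables are assumed to be `onnx.ModelProto` objects whose output/input order lines up):

```python
import onnx
from onnxruntime_extensions import util

# pre_m, core_m, post_m: pre-processing, beam-search core and post-processing
# models (assumed); with no connection_indices the i-th output of one model is
# wired to the i-th input of the next one.
final_m = util.quick_merge(pre_m, core_m, post_m)
onnx.save(final_m, "whisper_onnx_tiny_en_fp32_e2e.onnx")
```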
@@ -8,7 +8,9 @@

 FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
   static OrtOpLoader op_loader(
+      []() { return nullptr; }
 #ifdef ENABLE_DR_LIBS
+      ,
       CustomCpuStruct("AudioDecoder", AudioDecoder)
 #endif
   );
@@ -3,7 +3,7 @@
 import wave
 import unittest
 import numpy as np
-from onnxruntime_extensions import PyOrtFunction, util, make_onnx_model
+from onnxruntime_extensions import OrtPyFunction, util, make_onnx_model

 import onnx
 from onnx import onnx_pb as onnx_proto
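The audio tests below only switch the entry-point name from the old `PyOrtFunction` spelling to `OrtPyFunction`; the factory methods are used the same way, e.g.:

```python
from onnxruntime_extensions import OrtPyFunction

# Same call pattern as before the rename.
ortx_stft = OrtPyFunction.from_customop("StftNorm", cpu_only=True)
```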
@@ -12,6 +12,7 @@ from onnx import onnx_pb as onnx_proto
 _is_torch_available = False
 try:
     import torch
+
     _is_torch_available = True
 except ImportError:
     pass
@@ -19,6 +20,7 @@ except ImportError:
 _is_librosa_avaliable = False
 try:
     import librosa
+
     _is_librosa_avaliable = True
 except ImportError:
     pass
@@ -99,7 +101,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))

-        ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
+        ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
@@ -109,7 +111,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))

-        ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
+        ortx_stft = OrtPyFunction.from_customop("StftNorm", cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
@@ -125,7 +127,7 @@ class TestAudio(unittest.TestCase):
                              center=True,
                              return_complex=True).abs().pow(2).numpy()
         audio_pcm = np.expand_dims(self.test_pcm, axis=0)
-        ortx_stft = PyOrtFunction.from_customop("StftNorm")
+        ortx_stft = OrtPyFunction.from_customop("StftNorm")
         actual = ortx_stft(audio_pcm, 400, 160, np.hanning(wlen).astype(np.float32), 400)
         actual = actual[0]
         np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+import sys
 import unittest
 import transformers as _hfts

@@ -9,6 +10,7 @@ from packaging import version
 from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models


+@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
     def test_t5_tokenizer(self):
         tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
@@ -31,8 +33,7 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
         np.testing.assert_array_equal(ids[0][:-1], actual_ids)

-    @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
-    def test_whisper(self):
+    def test_whisper_overall(self):
         processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, post_m = gen_processing_models(processor,
                                               pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
@@ -46,10 +47,9 @@
         self.assertEqual(log_mel.shape, (1, 80, 3000))

         fn_post = OrtPyFunction.from_model(post_m)
-        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32))
+        rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :])
         self.assertEqual(rel[0], "$%&")

-    @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
     def test_whisper_audio_decoder(self):
         processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
@@ -57,11 +57,53 @@ class TestAutoTokenizer(unittest.TestCase):

         fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
         test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
-        raw_audio = np.fromfile(test_flac_file, dtype=np.uint8)
-        log_mel = fn_pre(np.expand_dims(raw_audio, axis=0))
+        audio_data = np.fromfile(test_flac_file, dtype=np.uint8)
+        log_mel = fn_pre(np.expand_dims(audio_data, axis=0))

         self.assertEqual(log_mel.shape, (1, 80, 3000))

+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_ort_stft_consistency(self):
+        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
+        # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%.
+        self.assertTrue(num_mismatched / np.size(expected) < 0.02)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
+    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
+    def test_stft_norm_consistency(self):
+        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+

 if __name__ == '__main__':
     unittest.main()
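The new `test_ort_stft_consistency` case tolerates a small fraction of mismatched log-mel values rather than exact closeness. A short sketch of how that fraction is computed (the arrays are hypothetical stand-ins for the compared features):

```python
import numpy as np

# Hypothetical stand-ins for the compared log-mel features.
expected = np.zeros((80, 3000), dtype=np.float32)
actual = expected.copy()
actual.flat[:1000] += 1e-3        # perturb ~0.4% of the 240000 entries beyond tolerance

num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05))
mismatch_ratio = num_mismatched / np.size(expected)
assert mismatch_ratio < 0.02      # the test allows roughly 2% of entries to differ
```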
@@ -1,291 +1,99 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import argparse
-import io
-
-
-import os
-import onnx
-import re
-import torch
-import numpy as np
-
-from onnx import numpy_helper
-from transformers import WhisperProcessor
-
-from onnxruntime_extensions import PyOrtFunction, util
-from onnxruntime_extensions.cvt import HFTokenizerConverter
-
-
-# the flags for pre-processing
-USE_ONNX_STFT = True
-USE_AUDIO_DECODER = True
-
-
-if not USE_AUDIO_DECODER:
-    try:
-        import librosa
-    except ImportError:
-        raise ImportError("Please pip3 install librosa without ort-extensions audio codec support.")
-
-
-# hard-coded audio hyperparameters
-# copied from https://github.com/openai/whisper/blob/main/whisper/audio.py#L12
-SAMPLE_RATE = 16000
-N_FFT = 400
-N_MELS = 80
-HOP_LENGTH = 160
-CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
-N_FRAMES = N_SAMPLES // HOP_LENGTH
-
-
-class CustomOpStftNorm(torch.autograd.Function):
-    @staticmethod
-    def symbolic(g, self, n_fft, hop_length, window):
-        t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
-        t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
-        t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
-        return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
-
-    @staticmethod
-    def forward(ctx, audio, n_fft, hop_length, window):
-        win_length = window.shape[0]
-        stft = torch.stft(audio, n_fft, hop_length, win_length, window,
-                          center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
-        return stft.abs() ** 2
-
-
-class WhisperPrePipeline(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.window = torch.hann_window(N_FFT)
-        self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
-
-    def forward(self, audio_pcm: torch.Tensor):
-        stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
-        magnitudes = stft_norm[:, :, :-1]
-        mel_spec = self.mel_filters @ magnitudes
-        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        spec_min = log_spec.max() - 8.0
-        log_spec = torch.maximum(log_spec, spec_min)
-        spec_shape = log_spec.shape
-        padding_spec = torch.ones(spec_shape[0],
-                                  spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
-        padding_spec *= spec_min
-        log_spec = torch.cat((log_spec, padding_spec), dim=2)
-        log_spec = (log_spec + 4.0) / 4.0
-        return log_spec
-
-
-def _to_onnx_stft(onnx_model):
-    """Convert custom-op STFT-Norm to ONNX STFT"""
-    node_idx = 0
-    new_stft_nodes = []
-    stft_norm_node = None
-    for node in onnx_model.graph.node:
-        if node.op_type == "StftNorm":
-            stft_norm_node = node
-            break
-        node_idx += 1
-
-    if stft_norm_node is None:
-        raise RuntimeError("Cannot find STFTNorm node in the graph")
-
-    make_node = onnx.helper.make_node
-    replaced_nodes = [
-        make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
-                  value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
-                                                name='const_14')),
-        make_node('Pad',
-                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
-                  outputs=['pad_1_output_0'], mode='reflect'),
-        make_node('STFT',
-                  inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
-                  outputs=['stft_output_0'], name='stft', domain='', onesided=1),
-        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
-                  perm=[0, 2, 1, 3]),
-        make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
-                  value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
-                  value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
-                  value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
-        make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
-                  value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
-        make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
-                                   'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
-                  name='slice_1'),
-        make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
-        make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
-        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
-                  name='gather_4', axis=3),
-        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
-                  name='gather_5', axis=3),
-        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
-        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
-        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
-    ]
-    new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
-    new_stft_nodes.extend(replaced_nodes)
-    new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
-    del onnx_model.graph.node[:]
-    onnx_model.graph.node.extend(new_stft_nodes)
-    onnx.checker.check_model(onnx_model)
-    return onnx_model
-
-
-def _torch_export(*arg, **kwargs):
-    with io.BytesIO() as f:
-        torch.onnx.export(*arg, f, **kwargs)
-        return onnx.load_from_string(f.getvalue())
-
-
-def preprocessing(audio_data):
-    if USE_AUDIO_DECODER:
-        decoder = PyOrtFunction.from_customop(
-            "AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
-        audio_pcm = torch.from_numpy(decoder(audio_data))
-    else:
-        audio_pcm = torch.from_numpy(audio_data)
-
-    prep_model_name = 'whisper_pre.onnx'
-    whisper_processing = WhisperPrePipeline()
-
-    model_args = (audio_pcm,)
-    pre_model = _torch_export(
-        whisper_processing,
-        model_args,
-        input_names=["audio_pcm"],
-        output_names=["log_mel"],
-        do_constant_folding=True,
-        export_params=True,
-        opset_version=17,
-        dynamic_axes={
-            "audio_pcm": {1: "sample_len"},
-        }
-    )
-    onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
-    if USE_ONNX_STFT:
-        pre_model = _to_onnx_stft(pre_model)
-        util.remove_unused_initializers(pre_model.graph)
-
-    pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
-    if not USE_AUDIO_DECODER:
-        return pre_f(audio_data)
-    else:
-        pre_full = onnx.compose.merge_models(
-            decoder.onnx_model,
-            pre_model,
-            io_map=[("floatPCM", "audio_pcm")])
-        pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
-
-        onnx.save_model(pre_f.onnx_model, os.path.join(root_dir, "whisper_codec_pre.onnx"))
-        result = pre_f(audio_data)
-        return result
-
-
-def merge_models(core: str, output_model: str, audio_data):
-    m_pre_path = os.path.join(root_dir, "whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
-    m_pre = onnx.load_model(m_pre_path)
-    m_core = onnx.load_model(core)
-    m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
-    m2 = onnx.load_model(os.path.join(root_dir, "whisper_post.onnx"))
-
-    m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
-    bpe_decoder_node = m_all.graph.node.pop(-1)
-    make_node = onnx.helper.make_node
-    bpe_decoder_node.input.pop(0)
-    bpe_decoder_node.input.extend(["generated_ids"])
-    m_all.graph.node.extend([
-        make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
-        bpe_decoder_node
-    ])
-    try:
-        onnx.save_model(m_all, output_model)
-    except ValueError:
-        onnx.save_model(m_all, output_model,
-                        save_as_external_data=True,
-                        all_tensors_to_one_file=True,
-                        location=f"{os.path.basename(output_model)}.data",
-                        convert_attribute=True)
-    print(f"The final merged model was saved as: {output_model}")
-
-    print("Verify the final model...")
-    m_final = PyOrtFunction.from_model(output_model, cpu_only=True)
-    output_text = m_final(audio_data,
-                          np.asarray([200], dtype=np.int32),
-                          np.asarray([0], dtype=np.int32),
-                          np.asarray([2], dtype=np.int32),
-                          np.asarray([1], dtype=np.int32),
-                          np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
-                          np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
-    print(output_text)
-
-
-def postprocessing(token_ids, hf_processor):
-    fn_decoder = PyOrtFunction.from_customop(
-        "BpeDecoder",
-        cvt=HFTokenizerConverter(hf_processor.tokenizer).bpe_decoder,
-        skip_special_tokens=True,
-        cpu_only=True)
-    onnx.save_model(fn_decoder.onnx_model, os.path.join(root_dir, "whisper_post.onnx"))
-    return fn_decoder(token_ids)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-a", "--audio", required=True, help="Path to audio file")
-    parser.add_argument("-m", "--model", required=True, help="Path to custom export of Whisper with beam search")
-    args = parser.parse_args()
-
-    print("Looking for the exported model...", end='')
-    onnx_model_name = os.path.basename(args.model)
-    if not re.search("whisper-.*_beamsearch\.onnx", onnx_model_name):
-        print("None")
-        print("Cannot find the whisper beamsearch ONNX models. "
-              "Please run this script from where Whisper ONNX model was exported. like */onnx_models/openai")
-        exit(-1)
-    else:
-        print(f"{onnx_model_name}")
-
-    model_name = "openai/" + onnx_model_name[:-len("_beamsearch.onnx")]
-    root_dir = os.path.dirname(args.model)
-
-    _processor = WhisperProcessor.from_pretrained(model_name)
-    # The model similar to Huggingface model like:
-    # model = WhisperForConditionalGeneration.from_pretrained(model_name)
-
-    # The onnx model can be generated by the following command:
-    # python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
-    # !!! only be valid after onnxruntime 1.15 or nightly build after 05/05/2023
-    model = PyOrtFunction.from_model(args.model, cpu_only=True)
-
-    test_file = util.get_test_data_file(args.audio)
-    if USE_AUDIO_DECODER:
-        with open(test_file, "rb") as _f:
-            audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
-    else:
-        audio_blob, _ = librosa.load(test_file)
-    audio_blob = np.expand_dims(audio_blob, axis=0)  # add a batch_size dimension
-
-    log_mel = preprocessing(audio_blob)
-    print(log_mel.shape)
-
-    input_features = log_mel
-    # similar to:
-    # generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
-    ort_outputs = model(input_features,
-                        np.asarray([200], dtype=np.int32),
-                        np.asarray([0], dtype=np.int32),
-                        np.asarray([2], dtype=np.int32),
-                        np.asarray([1], dtype=np.int32),
-                        np.asarray([1.0], dtype=np.float32),
-                        np.asarray([1.0], dtype=np.float32),
-                        np.zeros(input_features.shape).astype(np.int32))
-    generated_ids = ort_outputs[0]
-
-    text = postprocessing(generated_ids[0], _processor)
-    print(text)
-
-    print("build the final model...")
-    merge_models(args.model, args.model.replace("beamsearch", "all"), audio_blob)
+
+# Run the whisper end-to-end inference with ONNXRuntime-Extensions for pre/post processing.
+# THIS SCRIPT IS USED TO DEMO ONLY, WHICH IS NOT A PART OF THE PACKAGE.
+# TO GENERATE THE FULL-FUNCTION MODEL, PLEASE USE https://github.com/microsoft/Olive
+import os
+import onnx
+import subprocess
+import numpy as np
+import onnxruntime as ort
+
+from packaging import version
+from transformers import WhisperProcessor
+from onnxruntime_extensions import OrtPyFunction, util
+from onnxruntime_extensions.cvt import gen_processing_models
+
+# Constants
+MODEL_NAME = "openai/whisper-tiny.en"
+CACHE_DIR = 'temp_caches_onnx'
+OUTPUT_DIR = 'temp_model_onnx'
+FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"
+TEST_AUDIO_FILE = util.get_test_data_file('../test/data', "1272-141231-0002.mp3")
+
+
+def check_onnx_version():
+    if version.parse(ort.__version__) < version.parse("1.16.0"):
+        raise RuntimeError("ONNXRuntime version must >= 1.16.0")
+
+
+def export_onnx_model():
+    print("Exporting Whisper ONNX model from Huggingface model hub...")
+    command = ['python', '-m',
+               'onnxruntime.transformers.models.whisper.convert_to_onnx',
+               '-m', MODEL_NAME,
+               '--cache_dir', CACHE_DIR,
+               '--output', OUTPUT_DIR,
+               '--precision', 'fp32']
+    process = subprocess.run(command)
+    if process.returncode != 0:
+        raise RuntimeError("Failed to export the core ONNX models.")
+
+
+def process_test_file():
+    if not os.path.exists(TEST_AUDIO_FILE):
+        raise FileNotFoundError(f"Test audio path {TEST_AUDIO_FILE} does not exist.")
+
+    raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
+    _processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+    pre_m, post_m = gen_processing_models(_processor,
+                                          pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True},
+                                          post_kwargs={},
+                                          opset=17)
+
+    fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
+    return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
+
+
+def main():
+    check_onnx_version()
+    export_onnx_model()
+    log_mel, pre_m, post_m = process_test_file()
+
+    # Apply core ONNX model
+    fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
+    token_seq = fn_core(log_mel,
+                        np.asarray([200], dtype=np.int32),
+                        np.asarray([0], dtype=np.int32),
+                        np.asarray([2], dtype=np.int32),
+                        np.asarray([1], dtype=np.int32),
+                        np.asarray([1.0], dtype=np.float32),
+                        np.asarray([1.0], dtype=np.float32))
+    print(token_seq.shape)
+
+    # Apply post processing
+    fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
+    output_text = fn_post(token_seq)
+    print(output_text)
+
+    # Merge models and save final model
+    print("Combine the data processing graphs into the ONNX model...")
+    final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
+    onnx.save(final_m, FINAL_MODEL)
+
+    # Test the final model
+    raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
+    text = OrtPyFunction.from_model(final_m, cpu_only=True)(
+        np.expand_dims(raw_audio, axis=0),
+        np.asarray([200], dtype=np.int32),
+        np.asarray([0], dtype=np.int32),
+        np.asarray([2], dtype=np.int32),
+        np.asarray([1], dtype=np.int32),
+        np.asarray([1.0], dtype=np.float32),
+        np.asarray([1.0], dtype=np.float32))
+    print(text)
+
+
+if __name__ == "__main__":
+    main()
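For readers of the rewritten example, the scalar tensors passed after `log_mel` are the beam-search controls of the exported core model. A hedged sketch of the assumed mapping (the exact input names come from the onnxruntime Whisper exporter and are not shown in this diff):

```python
import numpy as np

# Assumed correspondence of the positional inputs used in main() above to the
# beam-search parameters of whisper-tiny.en_beamsearch.onnx (names are typical
# for the onnxruntime transformers Whisper export, not verified here).
beam_search_args = {
    "max_length": np.asarray([200], dtype=np.int32),
    "min_length": np.asarray([0], dtype=np.int32),
    "num_beams": np.asarray([2], dtype=np.int32),
    "num_return_sequences": np.asarray([1], dtype=np.int32),
    "length_penalty": np.asarray([1.0], dtype=np.float32),
    "repetition_penalty": np.asarray([1.0], dtype=np.float32),
}
```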