Add a merge step in whisper end-to-end script and fixed some issues (#399)
* add merged models in whisper model * verify the final model
This commit is contained in:
Родитель
77e63c8845
Коммит
711774db6b
|
@ -24,6 +24,9 @@ def get_opset_version_from_ort():
|
|||
}
|
||||
|
||||
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
|
||||
max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
|
||||
if ort_ver_string > max_ver:
|
||||
ort_ver_string = max_ver
|
||||
return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
|
||||
|
||||
|
||||
|
@ -41,8 +44,6 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
|
|||
|
||||
class OrtPyFunction:
|
||||
|
||||
__name__ = 'OrtPyFunction'
|
||||
|
||||
@classmethod
|
||||
def get_ort_session_options(cls):
|
||||
# ONNXRuntime has an issue to support reusing the SessionOptions object.
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import io
|
||||
import onnx
|
||||
import numpy
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
from onnx import numpy_helper
|
||||
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
|
||||
# from onnx import compose
|
||||
from pathlib import Path
|
||||
from onnxruntime_extensions import PyOrtFunction, util
|
||||
from onnxruntime_extensions import PyOrtFunction, pnp, util, optimize_model
|
||||
from onnxruntime_extensions.cvt import HFTokenizerConverter
|
||||
|
||||
|
||||
# the flags for pre-processing
|
||||
USE_ONNX_STFT = False
|
||||
USE_ONNX_STFT = True
|
||||
USE_ONNX_COREMODEL = True
|
||||
USE_AUDIO_DECODER = True
|
||||
|
||||
|
@ -51,22 +53,6 @@ class CustomOpStftNorm(torch.autograd.Function):
|
|||
return stft.abs() ** 2
|
||||
|
||||
|
||||
class CustomOpStft(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, self, n_fft, hop_length, window):
|
||||
t_frame_step = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
|
||||
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
return g.op("STFT", self, t_frame_step, window, t_frame_size)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, audio, n_fft, hop_length, window):
|
||||
win_length = window.shape[0]
|
||||
stft = torch.stft(audio, n_fft, hop_length, win_length, window,
|
||||
center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
||||
stft = torch.permute(stft, (0, 2, 1))
|
||||
return torch.view_as_real(stft)
|
||||
|
||||
|
||||
class WhisperPrePipeline(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -74,75 +60,160 @@ class WhisperPrePipeline(torch.nn.Module):
|
|||
self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
|
||||
|
||||
def forward(self, audio_pcm: torch.Tensor):
|
||||
if USE_AUDIO_DECODER:
|
||||
audio_pcm = audio_pcm.squeeze(0)
|
||||
|
||||
pad_len = N_SAMPLES - audio_pcm.shape[0]
|
||||
audio_pcm = torch.nn.functional.pad(audio_pcm, (0, pad_len), mode='constant', value=0)
|
||||
audio_pcm = audio_pcm.unsqueeze(0)
|
||||
|
||||
if USE_ONNX_STFT:
|
||||
stft = CustomOpStft.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
|
||||
stft_norm = stft[..., 0] ** 2 + stft[..., 1] ** 2
|
||||
stft_norm = torch.permute(stft_norm, (0, 2, 1))
|
||||
else:
|
||||
stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
|
||||
|
||||
stft_norm.squeeze_(0)
|
||||
magnitudes = stft_norm[:, :-1]
|
||||
magnitudes = stft_norm[:, :, :-1]
|
||||
mel_spec = self.mel_filters @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
spec_min = log_spec.max() - 8.0
|
||||
log_spec = torch.maximum(log_spec, spec_min)
|
||||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
|
||||
def _to_onnx_stft(onnx_model):
|
||||
"""Convert custom-op STFT-Norm to ONNX STFT"""
|
||||
node_idx = 0
|
||||
new_stft_nodes = []
|
||||
stft_norm_node = None
|
||||
for node in onnx_model.graph.node:
|
||||
if node.op_type == "StftNorm":
|
||||
stft_norm_node = node
|
||||
break
|
||||
node_idx += 1
|
||||
|
||||
if stft_norm_node is None:
|
||||
raise RuntimeError("Cannot find STFTNorm node in the graph")
|
||||
|
||||
make_node = onnx.helper.make_node
|
||||
replaced_nodes = [
|
||||
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
|
||||
value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
|
||||
name='const_14')),
|
||||
make_node('Pad',
|
||||
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
|
||||
outputs=['pad_1_output_0'], mode='reflect'),
|
||||
make_node('STFT',
|
||||
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
|
||||
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
|
||||
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
|
||||
perm=[0, 2, 1, 3]),
|
||||
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
|
||||
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
|
||||
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
|
||||
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
|
||||
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
|
||||
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
|
||||
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
|
||||
name='slice_1'),
|
||||
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
|
||||
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
|
||||
name='gather_4', axis=3),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
|
||||
name='gather_5', axis=3),
|
||||
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
|
||||
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
|
||||
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
|
||||
]
|
||||
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
|
||||
new_stft_nodes.extend(replaced_nodes)
|
||||
new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
|
||||
del onnx_model.graph.node[:]
|
||||
onnx_model.graph.node.extend(new_stft_nodes)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def _add_audio_pad_nodes():
|
||||
ec = pnp.ONNXElementContainer(17)
|
||||
_ox = ec.get_api()
|
||||
ec.add_node('Pad', [], [], [])
|
||||
|
||||
|
||||
def _torch_export(*arg, **kwargs):
|
||||
with io.BytesIO() as f:
|
||||
torch.onnx.export(*arg, f, **kwargs)
|
||||
return onnx.load_from_string(f.getvalue())
|
||||
|
||||
|
||||
def preprocessing(audio_data):
|
||||
if USE_AUDIO_DECODER:
|
||||
decoder = PyOrtFunction.from_customop("AudioDecoder")
|
||||
audio_pcm = torch.from_numpy(decoder(audio_data.unsqueeze_(0).numpy()))
|
||||
audio_pcm = torch.from_numpy(decoder(audio_data))
|
||||
else:
|
||||
audio_pcm = torch.from_numpy(audio_data)
|
||||
|
||||
prep_model_name = Path('whisper_pre.onnx')
|
||||
WhisperProcessing = WhisperPrePipeline()
|
||||
prep_model_name = 'whisper_pre.onnx'
|
||||
whisper_processing = WhisperPrePipeline()
|
||||
|
||||
model_args = (audio_pcm,)
|
||||
torch.onnx.export(
|
||||
WhisperProcessing,
|
||||
pre_model = _torch_export(
|
||||
whisper_processing,
|
||||
model_args,
|
||||
f=str(prep_model_name),
|
||||
input_names=["audio_pcm"],
|
||||
output_names=["log_mel"],
|
||||
do_constant_folding=True,
|
||||
export_params=True,
|
||||
opset_version=17,
|
||||
dynamic_axes={
|
||||
"audio_pcm": {0: "samp_len"},
|
||||
"audio_pcm": {1: "sample_len"},
|
||||
}
|
||||
)
|
||||
onnx.save_model(pre_model, prep_model_name)
|
||||
if USE_ONNX_STFT:
|
||||
pre_model = _to_onnx_stft(pre_model)
|
||||
|
||||
pre_f = PyOrtFunction.from_model(str(prep_model_name))
|
||||
pre_f = PyOrtFunction.from_model(pre_model)
|
||||
if not USE_AUDIO_DECODER:
|
||||
return pre_f(audio_pcm.numpy())
|
||||
return pre_f(audio_data)
|
||||
else:
|
||||
# pre_full = compose.merge_models(decoder.onnx_model,
|
||||
# onnx.load_model("whisper_pre.onnx"),
|
||||
# io_map=[("floatPCM", "audio_pcm")])
|
||||
# pre_f = PyOrtFunction.from_model(pre_full)
|
||||
pre_full = onnx.compose.merge_models(
|
||||
decoder.onnx_model,
|
||||
pre_model,
|
||||
io_map=[("floatPCM", "audio_pcm")])
|
||||
pre_f = PyOrtFunction.from_model(pre_full)
|
||||
|
||||
# onnx.compose has some bugs above, so we use the following workaround
|
||||
import copy
|
||||
new_graph_node = copy.deepcopy(pre_f.onnx_model.graph.node)
|
||||
del pre_f.onnx_model.graph.input[:]
|
||||
pre_f.onnx_model.graph.input.extend(decoder.onnx_model.graph.input)
|
||||
decoder.onnx_model.graph.node[0].output[0] = "audio_pcm"
|
||||
del pre_f.onnx_model.graph.node[:]
|
||||
pre_f.onnx_model.graph.node.extend(decoder.onnx_model.graph.node)
|
||||
pre_f.onnx_model.graph.node.extend(new_graph_node)
|
||||
onnx.save_model(pre_f.onnx_model, "whisper_aud_pre.onnx")
|
||||
pre_f = PyOrtFunction.from_model("whisper_aud_pre.onnx")
|
||||
return pre_f(audio_data.numpy())
|
||||
onnx.save_model(pre_f.onnx_model, "whisper_codec_pre.onnx")
|
||||
return pre_f(audio_data)
|
||||
|
||||
|
||||
def merge_models(core: str, output_model:str, audio_data):
|
||||
m_pre = onnx.load_model("whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
|
||||
m_core = onnx.load_model(core)
|
||||
m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
|
||||
m2 = onnx.load_model("whisper_post.onnx")
|
||||
|
||||
m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
|
||||
bpe_decoder_node = m_all.graph.node.pop(-1)
|
||||
make_node = onnx.helper.make_node
|
||||
bpe_decoder_node.input.pop(0)
|
||||
bpe_decoder_node.input.extend(["generated_ids"])
|
||||
m_all.graph.node.extend([
|
||||
make_node('Constant', [], ['squeeze0_axes_0'], value_ints=[1]),
|
||||
make_node('Squeeze', ['sequences', 'squeeze0_axes_0'], ['squeeze0_output_0']),
|
||||
make_node('Cast', ['squeeze0_output_0'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||
bpe_decoder_node
|
||||
])
|
||||
onnx.save_model(m_all, output_model.replace(".onnx", "_.onnx"))
|
||||
optimize_model(m_all, output_model)
|
||||
print(f"The final merged model was saved as: {output_model}")
|
||||
|
||||
print("Verify the final model...")
|
||||
m_final = PyOrtFunction.from_model(output_model)
|
||||
output_text = m_final(audio_data,
|
||||
np.asarray([200]),
|
||||
np.asarray([0]), np.asarray([2]), np.asarray([1]),
|
||||
np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
|
||||
np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
|
||||
print(output_text)
|
||||
|
||||
|
||||
def postprocessing(token_ids, hf_processor):
|
||||
|
@ -156,7 +227,7 @@ def postprocessing(token_ids, hf_processor):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("checking the model...")
|
||||
print("preparing the model...")
|
||||
model_name = "openai/whisper-base.en"
|
||||
onnx_model_name = "whisper-base.en_beamsearch.onnx"
|
||||
if not Path(onnx_model_name).is_file():
|
||||
|
@ -165,7 +236,7 @@ if __name__ == '__main__':
|
|||
|
||||
_processor = WhisperProcessor.from_pretrained(model_name)
|
||||
if USE_ONNX_COREMODEL:
|
||||
# The onnx model can be gereated by the following command:
|
||||
# The onnx model can be generated by the following command:
|
||||
# python <ONNXRUNTIME_DIR>\onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py
|
||||
# -m "openai/whisper-base.en" -e
|
||||
# !only be valid after onnxruntime 1.15 or main branch of 04/04/2023
|
||||
|
@ -176,22 +247,26 @@ if __name__ == '__main__':
|
|||
test_file = util.get_test_data_file("../test/data", "1272-141231-0002.mp3")
|
||||
if USE_AUDIO_DECODER:
|
||||
with open(test_file, "rb") as _f:
|
||||
audio_data = torch.asarray(list(_f.read()), dtype=torch.uint8)
|
||||
audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
|
||||
else:
|
||||
audio_data, _ = librosa.load(test_file)
|
||||
audio_blob, _ = librosa.load(test_file)
|
||||
audio_blob = np.expand_dims(audio_blob, axis=0) # add a batch_size dimension
|
||||
|
||||
log_mel = preprocessing(audio_data)
|
||||
log_mel = preprocessing(audio_blob)
|
||||
print(log_mel.shape)
|
||||
|
||||
input_features = numpy.expand_dims(log_mel, axis=0)
|
||||
input_features = log_mel
|
||||
if USE_ONNX_COREMODEL:
|
||||
ort_outputs = model(input_features, numpy.asarray([200]),
|
||||
numpy.asarray([0]), numpy.asarray([2]), numpy.asarray([1]),
|
||||
numpy.asarray([1.0], dtype=numpy.float32), numpy.asarray([1.0], dtype=numpy.float32),
|
||||
numpy.zeros(input_features.shape).astype(numpy.int32))
|
||||
ort_outputs = model(input_features, np.asarray([200]),
|
||||
np.asarray([0]), np.asarray([2]), np.asarray([1]),
|
||||
np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
|
||||
np.zeros(input_features.shape).astype(np.int32))
|
||||
generated_ids = ort_outputs[0]
|
||||
else:
|
||||
generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
|
||||
|
||||
text = postprocessing(generated_ids[0], _processor)
|
||||
print(text)
|
||||
|
||||
print("build the final model...")
|
||||
merge_models(onnx_model_name, onnx_model_name.replace("beamsearch", "all"), audio_blob)
|
||||
|
|
Загрузка…
Ссылка в новой задаче