Add a merge step in whisper end-to-end script and fixed some issues (#399)

* add merged models in whisper model

* verify the final model
This commit is contained in:
Wenbing Li 2023-04-17 16:37:06 -07:00 коммит произвёл GitHub
Родитель 77e63c8845
Коммит 711774db6b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 151 добавлений и 75 удалений

Просмотреть файл

@ -24,6 +24,9 @@ def get_opset_version_from_ort():
} }
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2]) ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
if ort_ver_string > max_ver:
ort_ver_string = max_ver
return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11) return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
@ -41,8 +44,6 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
class OrtPyFunction: class OrtPyFunction:
__name__ = 'OrtPyFunction'
@classmethod @classmethod
def get_ort_session_options(cls): def get_ort_session_options(cls):
# ONNXRuntime has an issue to support reusing the SessionOptions object. # ONNXRuntime has an issue to support reusing the SessionOptions object.

Просмотреть файл

@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
import io
import onnx import onnx
import numpy
import torch import torch
import numpy as np
from pathlib import Path
from onnx import numpy_helper
from transformers import WhisperProcessor, WhisperForConditionalGeneration from transformers import WhisperProcessor, WhisperForConditionalGeneration
# from onnx import compose from onnxruntime_extensions import PyOrtFunction, pnp, util, optimize_model
from pathlib import Path
from onnxruntime_extensions import PyOrtFunction, util
from onnxruntime_extensions.cvt import HFTokenizerConverter from onnxruntime_extensions.cvt import HFTokenizerConverter
# the flags for pre-processing # the flags for pre-processing
USE_ONNX_STFT = False USE_ONNX_STFT = True
USE_ONNX_COREMODEL = True USE_ONNX_COREMODEL = True
USE_AUDIO_DECODER = True USE_AUDIO_DECODER = True
@ -51,22 +53,6 @@ class CustomOpStftNorm(torch.autograd.Function):
return stft.abs() ** 2 return stft.abs() ** 2
class CustomOpStft(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_frame_step = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("STFT", self, t_frame_step, window, t_frame_size)
@staticmethod
def forward(ctx, audio, n_fft, hop_length, window):
win_length = window.shape[0]
stft = torch.stft(audio, n_fft, hop_length, win_length, window,
center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
stft = torch.permute(stft, (0, 2, 1))
return torch.view_as_real(stft)
class WhisperPrePipeline(torch.nn.Module): class WhisperPrePipeline(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -74,75 +60,160 @@ class WhisperPrePipeline(torch.nn.Module):
self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS)) self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
def forward(self, audio_pcm: torch.Tensor): def forward(self, audio_pcm: torch.Tensor):
if USE_AUDIO_DECODER: stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
audio_pcm = audio_pcm.squeeze(0) magnitudes = stft_norm[:, :, :-1]
pad_len = N_SAMPLES - audio_pcm.shape[0]
audio_pcm = torch.nn.functional.pad(audio_pcm, (0, pad_len), mode='constant', value=0)
audio_pcm = audio_pcm.unsqueeze(0)
if USE_ONNX_STFT:
stft = CustomOpStft.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
stft_norm = stft[..., 0] ** 2 + stft[..., 1] ** 2
stft_norm = torch.permute(stft_norm, (0, 2, 1))
else:
stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
stft_norm.squeeze_(0)
magnitudes = stft_norm[:, :-1]
mel_spec = self.mel_filters @ magnitudes mel_spec = self.mel_filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) spec_min = log_spec.max() - 8.0
log_spec = torch.maximum(log_spec, spec_min)
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
log_spec = (log_spec + 4.0) / 4.0 log_spec = (log_spec + 4.0) / 4.0
return log_spec return log_spec
def _to_onnx_stft(onnx_model):
"""Convert custom-op STFT-Norm to ONNX STFT"""
node_idx = 0
new_stft_nodes = []
stft_norm_node = None
for node in onnx_model.graph.node:
if node.op_type == "StftNorm":
stft_norm_node = node
break
node_idx += 1
if stft_norm_node is None:
raise RuntimeError("Cannot find STFTNorm node in the graph")
make_node = onnx.helper.make_node
replaced_nodes = [
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
name='const_14')),
make_node('Pad',
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
outputs=['pad_1_output_0'], mode='reflect'),
make_node('STFT',
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
perm=[0, 2, 1, 3]),
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
del onnx_model.graph.node[:]
onnx_model.graph.node.extend(new_stft_nodes)
onnx.checker.check_model(onnx_model)
return onnx_model
def _add_audio_pad_nodes():
ec = pnp.ONNXElementContainer(17)
_ox = ec.get_api()
ec.add_node('Pad', [], [], [])
def _torch_export(*arg, **kwargs):
with io.BytesIO() as f:
torch.onnx.export(*arg, f, **kwargs)
return onnx.load_from_string(f.getvalue())
def preprocessing(audio_data): def preprocessing(audio_data):
if USE_AUDIO_DECODER: if USE_AUDIO_DECODER:
decoder = PyOrtFunction.from_customop("AudioDecoder") decoder = PyOrtFunction.from_customop("AudioDecoder")
audio_pcm = torch.from_numpy(decoder(audio_data.unsqueeze_(0).numpy())) audio_pcm = torch.from_numpy(decoder(audio_data))
else: else:
audio_pcm = torch.from_numpy(audio_data) audio_pcm = torch.from_numpy(audio_data)
prep_model_name = Path('whisper_pre.onnx') prep_model_name = 'whisper_pre.onnx'
WhisperProcessing = WhisperPrePipeline() whisper_processing = WhisperPrePipeline()
model_args = (audio_pcm,) model_args = (audio_pcm,)
torch.onnx.export( pre_model = _torch_export(
WhisperProcessing, whisper_processing,
model_args, model_args,
f=str(prep_model_name),
input_names=["audio_pcm"], input_names=["audio_pcm"],
output_names=["log_mel"], output_names=["log_mel"],
do_constant_folding=True, do_constant_folding=True,
export_params=True, export_params=True,
opset_version=17, opset_version=17,
dynamic_axes={ dynamic_axes={
"audio_pcm": {0: "samp_len"}, "audio_pcm": {1: "sample_len"},
} }
) )
onnx.save_model(pre_model, prep_model_name)
if USE_ONNX_STFT:
pre_model = _to_onnx_stft(pre_model)
pre_f = PyOrtFunction.from_model(str(prep_model_name)) pre_f = PyOrtFunction.from_model(pre_model)
if not USE_AUDIO_DECODER: if not USE_AUDIO_DECODER:
return pre_f(audio_pcm.numpy()) return pre_f(audio_data)
else: else:
# pre_full = compose.merge_models(decoder.onnx_model, pre_full = onnx.compose.merge_models(
# onnx.load_model("whisper_pre.onnx"), decoder.onnx_model,
# io_map=[("floatPCM", "audio_pcm")]) pre_model,
# pre_f = PyOrtFunction.from_model(pre_full) io_map=[("floatPCM", "audio_pcm")])
pre_f = PyOrtFunction.from_model(pre_full)
# onnx.compose has some bugs above, so we use the following workaround onnx.save_model(pre_f.onnx_model, "whisper_codec_pre.onnx")
import copy return pre_f(audio_data)
new_graph_node = copy.deepcopy(pre_f.onnx_model.graph.node)
del pre_f.onnx_model.graph.input[:]
pre_f.onnx_model.graph.input.extend(decoder.onnx_model.graph.input) def merge_models(core: str, output_model:str, audio_data):
decoder.onnx_model.graph.node[0].output[0] = "audio_pcm" m_pre = onnx.load_model("whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
del pre_f.onnx_model.graph.node[:] m_core = onnx.load_model(core)
pre_f.onnx_model.graph.node.extend(decoder.onnx_model.graph.node) m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
pre_f.onnx_model.graph.node.extend(new_graph_node) m2 = onnx.load_model("whisper_post.onnx")
onnx.save_model(pre_f.onnx_model, "whisper_aud_pre.onnx")
pre_f = PyOrtFunction.from_model("whisper_aud_pre.onnx") m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
return pre_f(audio_data.numpy()) bpe_decoder_node = m_all.graph.node.pop(-1)
make_node = onnx.helper.make_node
bpe_decoder_node.input.pop(0)
bpe_decoder_node.input.extend(["generated_ids"])
m_all.graph.node.extend([
make_node('Constant', [], ['squeeze0_axes_0'], value_ints=[1]),
make_node('Squeeze', ['sequences', 'squeeze0_axes_0'], ['squeeze0_output_0']),
make_node('Cast', ['squeeze0_output_0'], ["generated_ids"], to=onnx.TensorProto.INT64),
bpe_decoder_node
])
onnx.save_model(m_all, output_model.replace(".onnx", "_.onnx"))
optimize_model(m_all, output_model)
print(f"The final merged model was saved as: {output_model}")
print("Verify the final model...")
m_final = PyOrtFunction.from_model(output_model)
output_text = m_final(audio_data,
np.asarray([200]),
np.asarray([0]), np.asarray([2]), np.asarray([1]),
np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
print(output_text)
def postprocessing(token_ids, hf_processor): def postprocessing(token_ids, hf_processor):
@ -156,7 +227,7 @@ def postprocessing(token_ids, hf_processor):
if __name__ == '__main__': if __name__ == '__main__':
print("checking the model...") print("preparing the model...")
model_name = "openai/whisper-base.en" model_name = "openai/whisper-base.en"
onnx_model_name = "whisper-base.en_beamsearch.onnx" onnx_model_name = "whisper-base.en_beamsearch.onnx"
if not Path(onnx_model_name).is_file(): if not Path(onnx_model_name).is_file():
@ -165,7 +236,7 @@ if __name__ == '__main__':
_processor = WhisperProcessor.from_pretrained(model_name) _processor = WhisperProcessor.from_pretrained(model_name)
if USE_ONNX_COREMODEL: if USE_ONNX_COREMODEL:
# The onnx model can be gereated by the following command: # The onnx model can be generated by the following command:
# python <ONNXRUNTIME_DIR>\onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py # python <ONNXRUNTIME_DIR>\onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py
# -m "openai/whisper-base.en" -e # -m "openai/whisper-base.en" -e
# !only be valid after onnxruntime 1.15 or main branch of 04/04/2023 # !only be valid after onnxruntime 1.15 or main branch of 04/04/2023
@ -176,22 +247,26 @@ if __name__ == '__main__':
test_file = util.get_test_data_file("../test/data", "1272-141231-0002.mp3") test_file = util.get_test_data_file("../test/data", "1272-141231-0002.mp3")
if USE_AUDIO_DECODER: if USE_AUDIO_DECODER:
with open(test_file, "rb") as _f: with open(test_file, "rb") as _f:
audio_data = torch.asarray(list(_f.read()), dtype=torch.uint8) audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
else: else:
audio_data, _ = librosa.load(test_file) audio_blob, _ = librosa.load(test_file)
audio_blob = np.expand_dims(audio_blob, axis=0) # add a batch_size dimension
log_mel = preprocessing(audio_data) log_mel = preprocessing(audio_blob)
print(log_mel.shape) print(log_mel.shape)
input_features = numpy.expand_dims(log_mel, axis=0) input_features = log_mel
if USE_ONNX_COREMODEL: if USE_ONNX_COREMODEL:
ort_outputs = model(input_features, numpy.asarray([200]), ort_outputs = model(input_features, np.asarray([200]),
numpy.asarray([0]), numpy.asarray([2]), numpy.asarray([1]), np.asarray([0]), np.asarray([2]), np.asarray([1]),
numpy.asarray([1.0], dtype=numpy.float32), numpy.asarray([1.0], dtype=numpy.float32), np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
numpy.zeros(input_features.shape).astype(numpy.int32)) np.zeros(input_features.shape).astype(np.int32))
generated_ids = ort_outputs[0] generated_ids = ort_outputs[0]
else: else:
generated_ids = model.generate(torch.from_numpy(input_features)).numpy() generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
text = postprocessing(generated_ids[0], _processor) text = postprocessing(generated_ids[0], _processor)
print(text) print(text)
print("build the final model...")
merge_models(onnx_model_name, onnx_model_name.replace("beamsearch", "all"), audio_blob)