diff --git a/onnxruntime_extensions/_ortapi2.py b/onnxruntime_extensions/_ortapi2.py
index 064be66d..41074399 100644
--- a/onnxruntime_extensions/_ortapi2.py
+++ b/onnxruntime_extensions/_ortapi2.py
@@ -24,6 +24,9 @@ def get_opset_version_from_ort():
     }
 
     ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
+    max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
+    if ort_ver_string > max_ver:
+        ort_ver_string = max_ver
     return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
 
 
@@ -41,8 +44,6 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
 
 
 class OrtPyFunction:
-    __name__ = 'OrtPyFunction'
-
     @classmethod
     def get_ort_session_options(cls):
         # ONNXRuntime has an issue to support reusing the SessionOptions object.
diff --git a/tutorials/whisper_e2e.py b/tutorials/whisper_e2e.py
index de17dc69..58a07372 100644
--- a/tutorials/whisper_e2e.py
+++ b/tutorials/whisper_e2e.py
@@ -1,18 +1,20 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+import io
 import onnx
-import numpy
 import torch
+import numpy as np
+
+from pathlib import Path
+from onnx import numpy_helper
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
-# from onnx import compose
-from pathlib import Path
-from onnxruntime_extensions import PyOrtFunction, util
+from onnxruntime_extensions import PyOrtFunction, pnp, util, optimize_model
 from onnxruntime_extensions.cvt import HFTokenizerConverter
 
 
 # the flags for pre-processing
-USE_ONNX_STFT = False
+USE_ONNX_STFT = True
 USE_ONNX_COREMODEL = True
 USE_AUDIO_DECODER = True
 
@@ -51,22 +53,6 @@ class CustomOpStftNorm(torch.autograd.Function):
         return stft.abs() ** 2
 
 
-class CustomOpStft(torch.autograd.Function):
-    @staticmethod
-    def symbolic(g, self, n_fft, hop_length, window):
-        t_frame_step = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
-        t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
-        return g.op("STFT", self, t_frame_step, window, t_frame_size)
-
-    @staticmethod
-    def forward(ctx, audio, n_fft, hop_length, window):
-        win_length = window.shape[0]
-        stft = torch.stft(audio, n_fft, hop_length, win_length, window,
-                          center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
-        stft = torch.permute(stft, (0, 2, 1))
-        return torch.view_as_real(stft)
-
-
 class WhisperPrePipeline(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -74,75 +60,160 @@ class WhisperPrePipeline(torch.nn.Module):
         self.mel_filters = torch.from_numpy(util.mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS))
 
     def forward(self, audio_pcm: torch.Tensor):
-        if USE_AUDIO_DECODER:
-            audio_pcm = audio_pcm.squeeze(0)
-
-        pad_len = N_SAMPLES - audio_pcm.shape[0]
-        audio_pcm = torch.nn.functional.pad(audio_pcm, (0, pad_len), mode='constant', value=0)
-        audio_pcm = audio_pcm.unsqueeze(0)
-
-        if USE_ONNX_STFT:
-            stft = CustomOpStft.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
-            stft_norm = stft[..., 0] ** 2 + stft[..., 1] ** 2
-            stft_norm = torch.permute(stft_norm, (0, 2, 1))
-        else:
-            stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
-
-        stft_norm.squeeze_(0)
-        magnitudes = stft_norm[:, :-1]
+        stft_norm = CustomOpStftNorm.apply(audio_pcm, N_FFT, HOP_LENGTH, self.window)
+        magnitudes = stft_norm[:, :, :-1]
         mel_spec = self.mel_filters @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        spec_min = log_spec.max() - 8.0
+        log_spec = torch.maximum(log_spec, spec_min)
+        spec_shape = log_spec.shape
+        padding_spec = torch.ones(spec_shape[0],
+                                  spec_shape[1], (N_SAMPLES // HOP_LENGTH - spec_shape[2]), dtype=torch.float)
+        padding_spec *= spec_min
+        log_spec = torch.cat((log_spec, padding_spec), dim=2)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec
 
 
+def _to_onnx_stft(onnx_model):
+    """Convert custom-op STFT-Norm to ONNX STFT"""
+    node_idx = 0
+    new_stft_nodes = []
+    stft_norm_node = None
+    for node in onnx_model.graph.node:
+        if node.op_type == "StftNorm":
+            stft_norm_node = node
+            break
+        node_idx += 1
+
+    if stft_norm_node is None:
+        raise RuntimeError("Cannot find STFTNorm node in the graph")
+
+    make_node = onnx.helper.make_node
+    replaced_nodes = [
+        make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
+                  value=numpy_helper.from_array(np.array([0, N_FFT // 2, 0, N_FFT // 2], dtype='int64'),
+                                                name='const_14')),
+        make_node('Pad',
+                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
+                  outputs=['pad_1_output_0'], mode='reflect'),
+        make_node('STFT',
+                  inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
+                  outputs=['stft_output_0'], name='stft', domain='', onesided=1),
+        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
+                  perm=[0, 2, 1, 3]),
+        make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
+                  value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
+        make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
+                  value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
+        make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
+                  value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
+        make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
+                  value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
+        make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
+                                   'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
+                  name='slice_1'),
+        make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
+        make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
+        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
+                  name='gather_4', axis=3),
+        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
+                  name='gather_5', axis=3),
+        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
+        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
+        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
+    ]
+    new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
+    new_stft_nodes.extend(replaced_nodes)
+    new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
+    del onnx_model.graph.node[:]
+    onnx_model.graph.node.extend(new_stft_nodes)
+    onnx.checker.check_model(onnx_model)
+    return onnx_model
+
+
+def _add_audio_pad_nodes():
+    ec = pnp.ONNXElementContainer(17)
+    _ox = ec.get_api()
+    ec.add_node('Pad', [], [], [])
+
+
+def _torch_export(*arg, **kwargs):
+    with io.BytesIO() as f:
+        torch.onnx.export(*arg, f, **kwargs)
+        return onnx.load_from_string(f.getvalue())
+
+
 def preprocessing(audio_data):
     if USE_AUDIO_DECODER:
         decoder = PyOrtFunction.from_customop("AudioDecoder")
-        audio_pcm = torch.from_numpy(decoder(audio_data.unsqueeze_(0).numpy()))
+        audio_pcm = torch.from_numpy(decoder(audio_data))
     else:
         audio_pcm = torch.from_numpy(audio_data)
 
-    prep_model_name = Path('whisper_pre.onnx')
-    WhisperProcessing = WhisperPrePipeline()
+    prep_model_name = 'whisper_pre.onnx'
+    whisper_processing = WhisperPrePipeline()
     model_args = (audio_pcm,)
-    torch.onnx.export(
-        WhisperProcessing,
+    pre_model = _torch_export(
+        whisper_processing,
         model_args,
-        f=str(prep_model_name),
         input_names=["audio_pcm"],
         output_names=["log_mel"],
         do_constant_folding=True,
         export_params=True,
         opset_version=17,
         dynamic_axes={
-            "audio_pcm": {0: "samp_len"},
+            "audio_pcm": {1: "sample_len"},
         }
     )
+    onnx.save_model(pre_model, prep_model_name)
+    if USE_ONNX_STFT:
+        pre_model = _to_onnx_stft(pre_model)
 
-    pre_f = PyOrtFunction.from_model(str(prep_model_name))
+    pre_f = PyOrtFunction.from_model(pre_model)
     if not USE_AUDIO_DECODER:
-        return pre_f(audio_pcm.numpy())
+        return pre_f(audio_data)
     else:
-        # pre_full = compose.merge_models(decoder.onnx_model,
-        #                                 onnx.load_model("whisper_pre.onnx"),
-        #                                 io_map=[("floatPCM", "audio_pcm")])
-        # pre_f = PyOrtFunction.from_model(pre_full)
+        pre_full = onnx.compose.merge_models(
+            decoder.onnx_model,
+            pre_model,
+            io_map=[("floatPCM", "audio_pcm")])
+        pre_f = PyOrtFunction.from_model(pre_full)
 
-        # onnx.compose has some bugs above, so we use the following workaround
-        import copy
-        new_graph_node = copy.deepcopy(pre_f.onnx_model.graph.node)
-        del pre_f.onnx_model.graph.input[:]
-        pre_f.onnx_model.graph.input.extend(decoder.onnx_model.graph.input)
-        decoder.onnx_model.graph.node[0].output[0] = "audio_pcm"
-        del pre_f.onnx_model.graph.node[:]
-        pre_f.onnx_model.graph.node.extend(decoder.onnx_model.graph.node)
-        pre_f.onnx_model.graph.node.extend(new_graph_node)
-        onnx.save_model(pre_f.onnx_model, "whisper_aud_pre.onnx")
-        pre_f = PyOrtFunction.from_model("whisper_aud_pre.onnx")
-        return pre_f(audio_data.numpy())
+        onnx.save_model(pre_f.onnx_model, "whisper_codec_pre.onnx")
+        return pre_f(audio_data)
+
+
+def merge_models(core: str, output_model: str, audio_data):
+    m_pre = onnx.load_model("whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
+    m_core = onnx.load_model(core)
+    m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
+    m2 = onnx.load_model("whisper_post.onnx")
+
+    m_all = onnx.compose.merge_models(m1, m2, io_map=[("sequences", "ids")])
+    bpe_decoder_node = m_all.graph.node.pop(-1)
+    make_node = onnx.helper.make_node
+    bpe_decoder_node.input.pop(0)
+    bpe_decoder_node.input.extend(["generated_ids"])
+    m_all.graph.node.extend([
+        make_node('Constant', [], ['squeeze0_axes_0'], value_ints=[1]),
+        make_node('Squeeze', ['sequences', 'squeeze0_axes_0'], ['squeeze0_output_0']),
+        make_node('Cast', ['squeeze0_output_0'], ["generated_ids"], to=onnx.TensorProto.INT64),
+        bpe_decoder_node
+    ])
+    onnx.save_model(m_all, output_model.replace(".onnx", "_.onnx"))
+    optimize_model(m_all, output_model)
+    print(f"The final merged model was saved as: {output_model}")
+
+    print("Verify the final model...")
+    m_final = PyOrtFunction.from_model(output_model)
+    output_text = m_final(audio_data,
+                          np.asarray([200]),
+                          np.asarray([0]), np.asarray([2]), np.asarray([1]),
+                          np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
+                          np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32))
+    print(output_text)
 
 
 def postprocessing(token_ids, hf_processor):
@@ -156,7 +227,7 @@ def postprocessing(token_ids, hf_processor):
 
 
 if __name__ == '__main__':
-    print("checking the model...")
+    print("preparing the model...")
     model_name = "openai/whisper-base.en"
     onnx_model_name = "whisper-base.en_beamsearch.onnx"
     if not Path(onnx_model_name).is_file():
@@ -165,7 +236,7 @@
 
     _processor = WhisperProcessor.from_pretrained(model_name)
     if USE_ONNX_COREMODEL:
-        # The onnx model can be gereated by the following command:
+        # The onnx model can be generated by the following command:
        # python \onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py
         # -m "openai/whisper-base.en" -e
         # !only be valid after onnxruntime 1.15 or main branch of 04/04/2023
@@ -176,22 +247,26 @@
 
     test_file = util.get_test_data_file("../test/data", "1272-141231-0002.mp3")
     if USE_AUDIO_DECODER:
         with open(test_file, "rb") as _f:
-            audio_data = torch.asarray(list(_f.read()), dtype=torch.uint8)
+            audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
     else:
-        audio_data, _ = librosa.load(test_file)
+        audio_blob, _ = librosa.load(test_file)
+    audio_blob = np.expand_dims(audio_blob, axis=0)  # add a batch_size dimension
 
-    log_mel = preprocessing(audio_data)
+    log_mel = preprocessing(audio_blob)
     print(log_mel.shape)
 
-    input_features = numpy.expand_dims(log_mel, axis=0)
+    input_features = log_mel
     if USE_ONNX_COREMODEL:
-        ort_outputs = model(input_features, numpy.asarray([200]),
-                            numpy.asarray([0]), numpy.asarray([2]), numpy.asarray([1]),
-                            numpy.asarray([1.0], dtype=numpy.float32), numpy.asarray([1.0], dtype=numpy.float32),
-                            numpy.zeros(input_features.shape).astype(numpy.int32))
+        ort_outputs = model(input_features, np.asarray([200]),
+                            np.asarray([0]), np.asarray([2]), np.asarray([1]),
+                            np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
+                            np.zeros(input_features.shape).astype(np.int32))
         generated_ids = ort_outputs[0]
     else:
         generated_ids = model.generate(torch.from_numpy(input_features)).numpy()
     text = postprocessing(generated_ids[0], _processor)
     print(text)
+
+    print("build the final model...")
+    merge_models(onnx_model_name, onnx_model_name.replace("beamsearch", "all"), audio_blob)
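
Usage note (an illustrative sketch, not part of the patch): once merge_models() above has produced the all-in-one model, it can also be run with a plain onnxruntime session after registering the onnxruntime-extensions custom ops. The model file name ("whisper-base.en_all.onnx"), the test file name, and the positional input order below are assumptions taken from the tutorial code above (the verification call in merge_models()); treat them as placeholders if the models are regenerated differently.

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# register the custom-op library (AudioDecoder, BpeDecoder, ...) with the session
opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession("whisper-base.en_all.onnx", opts)

# raw mp3 bytes with a leading batch dimension, as in the tutorial's __main__ block
with open("1272-141231-0002.mp3", "rb") as f:
    audio_blob = np.expand_dims(np.frombuffer(f.read(), dtype=np.uint8), axis=0)

# positional values mirror the verification call in merge_models() above;
# the last input uses the tutorial's constants N_MELS=80, N_FRAMES=3000
values = [audio_blob,
          np.asarray([200]), np.asarray([0]), np.asarray([2]), np.asarray([1]),
          np.asarray([1.0], dtype=np.float32), np.asarray([1.0], dtype=np.float32),
          np.zeros((1, 80, 3000), dtype=np.int32)]
feed = {inp.name: v for inp, v in zip(sess.get_inputs(), values)}
print(sess.run(None, feed))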