diff --git a/tutorials/whisper_e2e.py b/tutorials/whisper_e2e.py index bbf3188f..07779e9d 100644 --- a/tutorials/whisper_e2e.py +++ b/tutorials/whisper_e2e.py @@ -56,6 +56,48 @@ def process_test_file(): return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m +def get_model_inputs(ort_session, audio_data): + ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) + print(ort_names) + + inputs = [ + audio_data, # audio_stream/input_features + np.asarray([200], dtype=np.int32), # max_length + np.asarray([0], dtype=np.int32), # min_length + np.asarray([2], dtype=np.int32), # num_beams + np.asarray([1], dtype=np.int32), # num_return_sequences + np.asarray([1.0], dtype=np.float32), # length_penalty + np.asarray([1.0], dtype=np.float32), # repetition_penalty + ] + required_input_names = {"audio_stream", "input_features", "max_length", "min_length", "num_beams", + "num_return_sequences", "length_penalty", "repetition_penalty"} + + # Add optional inputs if present in model + batch_size = 1 + N_MELS = 80, N_FRAMES = 3000 + + vocab_size = 51864 if ".en" in MODEL_NAME else 51865 + decoder_start_token_id = 50257 if ".en" in MODEL_NAME else 50258 + + for name in ort_names: + if name in required_input_names: + continue + elif name == "vocab_mask": + inputs.append(np.ones(vocab_size, dtype=np.int32)) + elif name == "prefix_vocab_mask": + inputs.append(np.ones((batch_size, vocab_size), dtype=np.int32)) + elif name == "attention_mask": + # For older ORT versions that have the dummy attention mask input for the beam search op + inputs.append(np.zeros((batch_size, N_MELS, N_FRAMES), dtype=np.int32)) + elif name == "decoder_input_ids": + inputs.append(np.array([[decoder_start_token_id]], dtype=np.int32)) + elif name == "logits_processor": + inputs.append(np.array([1], dtype=np.int32)) + else: + raise NotImplementedError(f"'{name}' input is not supported") + return inputs + + def main(): check_onnx_version() export_onnx_model() @@ -63,13 +105,9 @@ def main(): # Apply core ONNX model fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True) - token_seq = fn_core(log_mel, - np.asarray([200], dtype=np.int32), - np.asarray([0], dtype=np.int32), - np.asarray([2], dtype=np.int32), - np.asarray([1], dtype=np.int32), - np.asarray([1.0], dtype=np.float32), - np.asarray([1.0], dtype=np.float32)) + fn_core_ort_session = fn_core._ensure_ort_session() + model_inputs = get_model_inputs(fn_core_ort_session, log_mel) + token_seq = fn_core(*model_inputs) print(token_seq.shape) # Apply post processing @@ -84,14 +122,11 @@ def main(): # Test the final model raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8) - text = OrtPyFunction.from_model(final_m, cpu_only=True)( - np.expand_dims(raw_audio, axis=0), - np.asarray([200], dtype=np.int32), - np.asarray([0], dtype=np.int32), - np.asarray([2], dtype=np.int32), - np.asarray([1], dtype=np.int32), - np.asarray([1.0], dtype=np.float32), - np.asarray([1.0], dtype=np.float32)) + raw_audio = np.expand_dims(raw_audio, axis=0) + e2e_model = OrtPyFunction.from_model(final_m, cpu_only=True) + e2e_model_ort_session = e2e_model._ensure_ort_session() + model_inputs = get_model_inputs(e2e_model_ort_session, raw_audio) + text = e2e_model(*model_inputs) print(text)