Update inputs in Whisper E2E script (#511)
This PR updates the inputs for the inference pass to show the required and optional ones.
This commit is contained in:
Родитель
ab5710f82d
Коммит
c8bb9e8abd
|
@ -56,6 +56,48 @@ def process_test_file():
|
|||
return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
|
||||
|
||||
|
||||
def get_model_inputs(ort_session, audio_data):
    """Build the positional input list for a Whisper beam-search ONNX model.

    The first seven entries are the inputs this script always supplies
    (audio/features plus the beam-search parameters), in a fixed order.
    Any further input names reported by the session are treated as optional
    inputs and get a default value appended in the order the session lists
    them.

    Args:
        ort_session: Object exposing ``get_inputs()``, each entry of which has
            a ``name`` attribute (e.g. an ONNX Runtime inference session).
        audio_data: Array passed through unchanged as the first model input
            (``audio_stream`` or ``input_features``, depending on the model).

    Returns:
        A list of numpy arrays suitable for ``model(*inputs)``.

    Raises:
        NotImplementedError: If the model declares an input name this helper
            does not know how to populate.
    """
    ort_names = [entry.name for entry in ort_session.get_inputs()]
    print(ort_names)

    inputs = [
        audio_data,                           # audio_stream/input_features
        np.asarray([200], dtype=np.int32),    # max_length
        np.asarray([0], dtype=np.int32),      # min_length
        np.asarray([2], dtype=np.int32),      # num_beams
        np.asarray([1], dtype=np.int32),      # num_return_sequences
        np.asarray([1.0], dtype=np.float32),  # length_penalty
        np.asarray([1.0], dtype=np.float32),  # repetition_penalty
    ]
    required_input_names = {
        "audio_stream", "input_features", "max_length", "min_length", "num_beams",
        "num_return_sequences", "length_penalty", "repetition_penalty",
    }

    # Add optional inputs if present in model
    batch_size = 1
    # Two separate assignments: `N_MELS = 80, N_FRAMES = 3000` on a single line
    # is a SyntaxError (it tries to assign to the tuple `(80, N_FRAMES)`).
    N_MELS = 80
    N_FRAMES = 3000

    # The ".en" checkpoints use a different vocabulary size and start token id.
    is_english_only = ".en" in MODEL_NAME
    vocab_size = 51864 if is_english_only else 51865
    decoder_start_token_id = 50257 if is_english_only else 50258

    # NOTE(review): optional values are appended in session order, which assumes
    # the model lists its required inputs before any optional ones -- confirm.
    for name in ort_names:
        if name in required_input_names:
            continue
        elif name == "vocab_mask":
            inputs.append(np.ones(vocab_size, dtype=np.int32))
        elif name == "prefix_vocab_mask":
            inputs.append(np.ones((batch_size, vocab_size), dtype=np.int32))
        elif name == "attention_mask":
            # For older ORT versions that have the dummy attention mask input for the beam search op
            inputs.append(np.zeros((batch_size, N_MELS, N_FRAMES), dtype=np.int32))
        elif name == "decoder_input_ids":
            inputs.append(np.array([[decoder_start_token_id]], dtype=np.int32))
        elif name == "logits_processor":
            inputs.append(np.array([1], dtype=np.int32))
        else:
            raise NotImplementedError(f"'{name}' input is not supported")
    return inputs
|
||||
|
||||
|
||||
def main():
|
||||
check_onnx_version()
|
||||
export_onnx_model()
|
||||
|
@ -63,13 +105,9 @@ def main():
|
|||
|
||||
# Apply core ONNX model
|
||||
fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
|
||||
token_seq = fn_core(log_mel,
|
||||
np.asarray([200], dtype=np.int32),
|
||||
np.asarray([0], dtype=np.int32),
|
||||
np.asarray([2], dtype=np.int32),
|
||||
np.asarray([1], dtype=np.int32),
|
||||
np.asarray([1.0], dtype=np.float32),
|
||||
np.asarray([1.0], dtype=np.float32))
|
||||
fn_core_ort_session = fn_core._ensure_ort_session()
|
||||
model_inputs = get_model_inputs(fn_core_ort_session, log_mel)
|
||||
token_seq = fn_core(*model_inputs)
|
||||
print(token_seq.shape)
|
||||
|
||||
# Apply post processing
|
||||
|
@ -84,14 +122,11 @@ def main():
|
|||
|
||||
# Test the final model
|
||||
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
|
||||
text = OrtPyFunction.from_model(final_m, cpu_only=True)(
|
||||
np.expand_dims(raw_audio, axis=0),
|
||||
np.asarray([200], dtype=np.int32),
|
||||
np.asarray([0], dtype=np.int32),
|
||||
np.asarray([2], dtype=np.int32),
|
||||
np.asarray([1], dtype=np.int32),
|
||||
np.asarray([1.0], dtype=np.float32),
|
||||
np.asarray([1.0], dtype=np.float32))
|
||||
raw_audio = np.expand_dims(raw_audio, axis=0)
|
||||
e2e_model = OrtPyFunction.from_model(final_m, cpu_only=True)
|
||||
e2e_model_ort_session = e2e_model._ensure_ort_session()
|
||||
model_inputs = get_model_inputs(e2e_model_ort_session, raw_audio)
|
||||
text = e2e_model(*model_inputs)
|
||||
print(text)
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче