Update inputs in Whisper E2E script (#511)

This PR updates the inputs for the inference pass to show the required and optional ones.
This commit is contained in:
kunal-vaishnavi 2023-08-08 15:46:21 -07:00 коммит произвёл GitHub
Родитель ab5710f82d
Коммит c8bb9e8abd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 50 добавлений и 15 удалений

Просмотреть файл

@ -56,6 +56,48 @@ def process_test_file():
return fn_pre(np.expand_dims(raw_audio, axis=0)), pre_m, post_m
def get_model_inputs(ort_session, audio_data):
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
print(ort_names)
inputs = [
audio_data, # audio_stream/input_features
np.asarray([200], dtype=np.int32), # max_length
np.asarray([0], dtype=np.int32), # min_length
np.asarray([2], dtype=np.int32), # num_beams
np.asarray([1], dtype=np.int32), # num_return_sequences
np.asarray([1.0], dtype=np.float32), # length_penalty
np.asarray([1.0], dtype=np.float32), # repetition_penalty
]
required_input_names = {"audio_stream", "input_features", "max_length", "min_length", "num_beams",
"num_return_sequences", "length_penalty", "repetition_penalty"}
# Add optional inputs if present in model
batch_size = 1
N_MELS = 80, N_FRAMES = 3000
vocab_size = 51864 if ".en" in MODEL_NAME else 51865
decoder_start_token_id = 50257 if ".en" in MODEL_NAME else 50258
for name in ort_names:
if name in required_input_names:
continue
elif name == "vocab_mask":
inputs.append(np.ones(vocab_size, dtype=np.int32))
elif name == "prefix_vocab_mask":
inputs.append(np.ones((batch_size, vocab_size), dtype=np.int32))
elif name == "attention_mask":
# For older ORT versions that have the dummy attention mask input for the beam search op
inputs.append(np.zeros((batch_size, N_MELS, N_FRAMES), dtype=np.int32))
elif name == "decoder_input_ids":
inputs.append(np.array([[decoder_start_token_id]], dtype=np.int32))
elif name == "logits_processor":
inputs.append(np.array([1], dtype=np.int32))
else:
raise NotImplementedError(f"'{name}' input is not supported")
return inputs
def main():
check_onnx_version()
export_onnx_model()
@ -63,13 +105,9 @@ def main():
# Apply core ONNX model
fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, "whisper-tiny.en_beamsearch.onnx"), cpu_only=True)
token_seq = fn_core(log_mel,
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32),
np.asarray([1.0], dtype=np.float32))
fn_core_ort_session = fn_core._ensure_ort_session()
model_inputs = get_model_inputs(fn_core_ort_session, log_mel)
token_seq = fn_core(*model_inputs)
print(token_seq.shape)
# Apply post processing
@ -84,14 +122,11 @@ def main():
# Test the final model
raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
text = OrtPyFunction.from_model(final_m, cpu_only=True)(
np.expand_dims(raw_audio, axis=0),
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32),
np.asarray([1.0], dtype=np.float32))
raw_audio = np.expand_dims(raw_audio, axis=0)
e2e_model = OrtPyFunction.from_model(final_m, cpu_only=True)
e2e_model_ort_session = e2e_model._ensure_ort_session()
model_inputs = get_model_inputs(e2e_model_ort_session, raw_audio)
text = e2e_model(*model_inputs)
print(text)