Fix the unit test failure with ONNX 1.14 package. (#428)

* Fix the unit test failure with ONNX 1.14 package.

* more tests

* Update whisper_e2e.py
This commit is contained in:
Wenbing Li 2023-05-08 11:37:54 -07:00 коммит произвёл GitHub
Родитель b7b8816dab
Коммит 43994eb34a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 113 добавлений и 13 удалений

Просмотреть файл

@ -1,5 +1,5 @@
import os
import fire
import argparse
import onnx
import numpy
@ -36,5 +36,22 @@ class ORTExtCommands:
print("The extensions loaded, status: OK.")
def main():
    """Command-line entry point for the ORT extension commands.

    Parses the positional ``command`` (``run`` or ``selfcheck``), the
    ``--model`` and ``--testdata-dir`` options, and any trailing arguments,
    then invokes the matching ``ORTExtCommands`` method with the trailing
    arguments passed through positionally.
    """
    cli = argparse.ArgumentParser(description="ORT Extension commands")
    cli.add_argument("command", choices=["run", "selfcheck"])
    cli.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
    cli.add_argument("--testdata-dir", help="Path to the test data directory")
    cli.add_argument("args", nargs=argparse.REMAINDER, help="Additional arguments")
    parsed = cli.parse_args()

    commands = ORTExtCommands(model=parsed.model, testdata_dir=parsed.testdata_dir)
    extra_args = parsed.args
    # ``choices`` above restricts ``command`` to exactly these two names.
    if parsed.command == "selfcheck":
        commands.selfcheck(*extra_args)
    elif parsed.command == "run":
        commands.run(*extra_args)
if __name__ == '__main__':
fire.Fire(ORTExtCommands)
main()

Просмотреть файл

@ -5,6 +5,8 @@ from typing import Any
from onnx.onnx_pb import TensorProto
from torch.onnx import TrainingMode, export as _export
from ._onnx_ops import OPSET_TO_IR_VERSION
def _export_f(model, *args,
opset_version=None,
@ -32,6 +34,9 @@ def _export_f(model, *args,
custom_opsets=custom_opsets)
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
for ops in mdl.opset_import:
if ops.domain in ('', 'ai.onnx'):
mdl.ir_version = OPSET_TO_IR_VERSION[ops.version]
if output_path is not None:
if output_seq > 0:
output_path.replace('.onnx', '.{}.onnx'.format(output_seq))

Просмотреть файл

@ -15,6 +15,8 @@ OPSET_TO_IR_VERSION = {
7: 3, 8: 3, 9: 4, 10: 5, 11: 6, 12: 7,
13: 7, 14: 7, 15: 8, 16: 8, 17: 8
}
if hasattr(helper, 'VERSION_TABLE'):
OPSET_TO_IR_VERSION = {row[2]: row[1] for row in helper.VERSION_TABLE}
def _get_main_opset_version(model):

Просмотреть файл

@ -1,6 +1,6 @@
import copy
import onnx
from onnx import numpy_helper
from onnx import helper, numpy_helper
from collections import namedtuple
@ -271,20 +271,32 @@ class ONNXModelUtils:
del _n.input[:]
_n.input.extend([port_mapping[_i] if _i in port_mapping else _i for _i in new_input])
name = ''
name = "_".join([_mdl.graph.name for _mdl in models])
domains = set()
_opset = []
for _mdl in models:
for _ops in _mdl.opset_import:
if _ops.domain not in domains:
domains.update([_ops.domain])
_opset.append(_ops)
name = name + '_' + _mdl.graph.name if name else _mdl.graph.name
domain = _ops.domain if _ops.domain else "ai.onnx"
if domain in domains:
if domain == "ai.onnx":
assert _ops.version == _opset[0].version, \
f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
else:
domains.add(domain)
if domain == "ai.onnx":
_opset.insert(0, _ops)
else:
_opset.append(_ops)
inits = cls._remove_unused_initializers(nodes, container.initializer)
helper = onnx.helper
g = helper.make_graph(nodes, name, inputs, outputs,
initializer=inits,
value_info=container.value_info)
m = helper.make_model(g, opset_imports=_opset)
if hasattr(helper, 'make_model_gen_version'):
# make_model_gen_version doesn't accept the custom domain.
m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
m.opset_import.extend(_opset[1:])
else:
m = helper.make_model(g, opset_imports=_opset)
return m

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import onnx
import pathlib
import inspect
@ -57,3 +58,66 @@ def mel_filterbank(
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank
def _nested_graphs(node):
    """Yield every subgraph stored in the attributes of ``node``."""
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            yield attr.g
        elif attr.type == onnx.AttributeProto.GRAPHS:
            yield from attr.graphs


def _referenced_tensor_names(graph):
    """Return the set of tensor names that ``graph`` still needs.

    A name counts as needed when it is an input of any node in ``graph``,
    an output of ``graph`` itself, or needed (by the same rule) by any
    subgraph nested inside one of its nodes. Nested subgraphs are included
    because they may read values defined in the enclosing graph.
    """
    used = {out.name for out in graph.output}
    for node in graph.node:
        used.update(node.input)
        for nested in _nested_graphs(node):
            used |= _referenced_tensor_names(nested)
    return used


def remove_unused_constants(subgraph):
    """Remove ``Constant`` nodes whose output nothing consumes, in place.

    A Constant is kept when its output feeds another node, is a graph
    output, or is referenced from a nested subgraph. Nested subgraphs are
    then cleaned recursively.

    :param subgraph: the graph to clean; its ``node`` list is mutated.
    """
    used = _referenced_tensor_names(subgraph)

    # Iterate over a snapshot because the node list is mutated below.
    for node in list(subgraph.node):
        if node.op_type == 'Constant' and node.output[0] not in used:
            subgraph.node.remove(node)

    # Recursively process subgraphs within this subgraph
    for node in subgraph.node:
        for nested in _nested_graphs(node):
            remove_unused_constants(nested)


def remove_unused_initializers(subgraph, top_level_initializers=None):
    """Remove unused ``Constant`` nodes and unused initializers, in place.

    An initializer is kept when its name feeds a node, is a graph output,
    or is referenced from a nested subgraph. The relative order of the
    surviving initializers is preserved. Nested subgraphs are then cleaned
    recursively.

    :param subgraph: the graph to clean; its ``node`` and ``initializer``
        lists are mutated.
    :param top_level_initializers: accepted for backward compatibility and
        forwarded to the recursive calls; it does not influence which
        initializers of ``subgraph`` are kept.
    """
    if top_level_initializers is None:
        top_level_initializers = []
    remove_unused_constants(subgraph)

    used = _referenced_tensor_names(subgraph)

    # Delete by index, back to front, so earlier indices stay valid and no
    # element-by-element equality comparison of tensors is needed.
    for idx in reversed(range(len(subgraph.initializer))):
        if subgraph.initializer[idx].name not in used:
            del subgraph.initializer[idx]

    # Recursively process subgraphs within this subgraph
    for node in subgraph.node:
        for nested in _nested_graphs(node):
            remove_unused_initializers(nested, top_level_initializers)

Просмотреть файл

@ -166,6 +166,7 @@ def preprocessing(audio_data):
onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
if USE_ONNX_STFT:
pre_model = _to_onnx_stft(pre_model)
util.remove_unused_initializers(pre_model.graph)
pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
if not USE_AUDIO_DECODER:
@ -255,9 +256,8 @@ if __name__ == '__main__':
# model = WhisperForConditionalGeneration.from_pretrained(model_name)
# The onnx model can be generated by the following command:
# python <ONNXRUNTIME_DIR>\onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py
# -m "openai/whisper-base.en" -e
# ! only valid for onnxruntime 1.15 and later, or the main branch as of 04/04/2023
# python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
# !!! only valid for onnxruntime 1.15 and later, or a nightly build after 05/05/2023
model = PyOrtFunction.from_model(args.model, cpu_only=True)
test_file = util.get_test_data_file(args.audio)