Fix the unit test failure with ONNX 1.14 package. (#428)
* Fix the unit test failure with ONNX 1.14 package. * more tests * Update whisper_e2e.py
This commit is contained in:
Родитель
b7b8816dab
Коммит
43994eb34a
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
import fire
|
||||
import argparse
|
||||
import onnx
|
||||
import numpy
|
||||
|
||||
|
@ -36,5 +36,22 @@ class ORTExtCommands:
|
|||
print("The extensions loaded, status: OK.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ORT Extension commands")
|
||||
parser.add_argument("command", choices=["run", "selfcheck"])
|
||||
parser.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
|
||||
parser.add_argument("--testdata-dir", help="Path to the test data directory")
|
||||
parser.add_argument("args", nargs=argparse.REMAINDER, help="Additional arguments")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ort_commands = ORTExtCommands(model=args.model, testdata_dir=args.testdata_dir)
|
||||
|
||||
if args.command == "run":
|
||||
ort_commands.run(*args.args)
|
||||
elif args.command == "selfcheck":
|
||||
ort_commands.selfcheck(*args.args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(ORTExtCommands)
|
||||
main()
|
||||
|
|
|
@ -5,6 +5,8 @@ from typing import Any
|
|||
from onnx.onnx_pb import TensorProto
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
|
||||
from ._onnx_ops import OPSET_TO_IR_VERSION
|
||||
|
||||
|
||||
def _export_f(model, *args,
|
||||
opset_version=None,
|
||||
|
@ -32,6 +34,9 @@ def _export_f(model, *args,
|
|||
custom_opsets=custom_opsets)
|
||||
|
||||
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
|
||||
for ops in mdl.opset_import:
|
||||
if ops.domain in ('', 'ai.onnx'):
|
||||
mdl.ir_version = OPSET_TO_IR_VERSION[ops.version]
|
||||
if output_path is not None:
|
||||
if output_seq > 0:
|
||||
output_path.replace('.onnx', '.{}.onnx'.format(output_seq))
|
||||
|
|
|
@ -15,6 +15,8 @@ OPSET_TO_IR_VERSION = {
|
|||
7: 3, 8: 3, 9: 4, 10: 5, 11: 6, 12: 7,
|
||||
13: 7, 14: 7, 15: 8, 16: 8, 17: 8
|
||||
}
|
||||
if hasattr(helper, 'VERSION_TABLE'):
|
||||
OPSET_TO_IR_VERSION = {row[2]: row[1] for row in helper.VERSION_TABLE}
|
||||
|
||||
|
||||
def _get_main_opset_version(model):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import copy
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from onnx import helper, numpy_helper
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
|
@ -271,20 +271,32 @@ class ONNXModelUtils:
|
|||
del _n.input[:]
|
||||
_n.input.extend([port_mapping[_i] if _i in port_mapping else _i for _i in new_input])
|
||||
|
||||
name = ''
|
||||
name = "_".join([_mdl.graph.name for _mdl in models])
|
||||
domains = set()
|
||||
_opset = []
|
||||
for _mdl in models:
|
||||
for _ops in _mdl.opset_import:
|
||||
if _ops.domain not in domains:
|
||||
domains.update([_ops.domain])
|
||||
_opset.append(_ops)
|
||||
name = name + '_' + _mdl.graph.name if name else _mdl.graph.name
|
||||
domain = _ops.domain if _ops.domain else "ai.onnx"
|
||||
if domain in domains:
|
||||
if domain == "ai.onnx":
|
||||
assert _ops.version == _opset[0].version, \
|
||||
f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
|
||||
else:
|
||||
domains.add(domain)
|
||||
if domain == "ai.onnx":
|
||||
_opset.insert(0, _ops)
|
||||
else:
|
||||
_opset.append(_ops)
|
||||
|
||||
inits = cls._remove_unused_initializers(nodes, container.initializer)
|
||||
helper = onnx.helper
|
||||
g = helper.make_graph(nodes, name, inputs, outputs,
|
||||
initializer=inits,
|
||||
value_info=container.value_info)
|
||||
m = helper.make_model(g, opset_imports=_opset)
|
||||
|
||||
if hasattr(helper, 'make_model_gen_version'):
|
||||
# make_model_gen_version doesn't accept the custom domain.
|
||||
m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
|
||||
m.opset_import.extend(_opset[1:])
|
||||
else:
|
||||
m = helper.make_model(g, opset_imports=_opset)
|
||||
return m
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import onnx
|
||||
import pathlib
|
||||
import inspect
|
||||
|
||||
|
@ -57,3 +58,66 @@ def mel_filterbank(
|
|||
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
|
||||
fbank *= energy_norm[:, np.newaxis]
|
||||
return fbank
|
||||
|
||||
|
||||
def remove_unused_constants(subgraph):
|
||||
nodes = [_n for _n in subgraph.node]
|
||||
|
||||
# Find the names of all input tensors for all nodes in the subgraph
|
||||
input_tensors = set()
|
||||
for node in nodes:
|
||||
for input_name in node.input:
|
||||
input_tensors.add(input_name)
|
||||
|
||||
# Remove Constant nodes whose output is not used by any other nodes
|
||||
nodes_to_remove = []
|
||||
for node in nodes:
|
||||
if node.op_type == 'Constant':
|
||||
output_name = node.output[0]
|
||||
if output_name not in input_tensors:
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
for node in nodes_to_remove:
|
||||
subgraph.node.remove(node)
|
||||
|
||||
# Recursively process subgraphs within this subgraph
|
||||
for node in nodes:
|
||||
for attr in node.attribute:
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
remove_unused_constants(attr.g)
|
||||
elif attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for subgraph in attr.graphs:
|
||||
remove_unused_constants(subgraph)
|
||||
|
||||
|
||||
def remove_unused_initializers(subgraph, top_level_initializers=None):
|
||||
if top_level_initializers is None:
|
||||
top_level_initializers = []
|
||||
remove_unused_constants(subgraph)
|
||||
initializers = [_i for _i in subgraph.initializer]
|
||||
nodes = subgraph.node
|
||||
|
||||
# Find the names of all input tensors for all nodes in the subgraph
|
||||
input_tensors = set()
|
||||
for node in nodes:
|
||||
for input_name in node.input:
|
||||
input_tensors.add(input_name)
|
||||
|
||||
# Combine top-level and current subgraph initializers
|
||||
all_initializers = initializers + top_level_initializers
|
||||
|
||||
# Filter the initializers by checking if their names are in the list of used input tensors
|
||||
used_initializers = [init for init in all_initializers if init.name in input_tensors]
|
||||
|
||||
# Update the subgraph's initializers
|
||||
del subgraph.initializer[:]
|
||||
subgraph.initializer.extend([init for init in used_initializers if init in initializers])
|
||||
|
||||
# Recursively process subgraphs within this subgraph
|
||||
for node in nodes:
|
||||
for attr in node.attribute:
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
remove_unused_initializers(attr.g, top_level_initializers)
|
||||
elif attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for subgraph in attr.graphs:
|
||||
remove_unused_initializers(subgraph, top_level_initializers)
|
||||
|
|
|
@ -166,6 +166,7 @@ def preprocessing(audio_data):
|
|||
onnx.save_model(pre_model, os.path.join(root_dir, prep_model_name))
|
||||
if USE_ONNX_STFT:
|
||||
pre_model = _to_onnx_stft(pre_model)
|
||||
util.remove_unused_initializers(pre_model.graph)
|
||||
|
||||
pre_f = PyOrtFunction.from_model(pre_model, cpu_only=True)
|
||||
if not USE_AUDIO_DECODER:
|
||||
|
@ -255,9 +256,8 @@ if __name__ == '__main__':
|
|||
# model = WhisperForConditionalGeneration.from_pretrained(model_name)
|
||||
|
||||
# The onnx model can be generated by the following command:
|
||||
# python <ONNXRUNTIME_DIR>\onnxruntime\python\tools\transformers\models\whisper\convert_to_onnx.py
|
||||
# -m "openai/whisper-base.en" -e
|
||||
# !only be valid after onnxruntime 1.15 or main branch of 04/04/2023
|
||||
# python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m "openai/whisper-base.en" -e
|
||||
# !!! only be valid after onnxruntime 1.15 or nightly build after 05/05/2023
|
||||
model = PyOrtFunction.from_model(args.model, cpu_only=True)
|
||||
|
||||
test_file = util.get_test_data_file(args.audio)
|
||||
|
|
Загрузка…
Ссылка в новой задаче