onnxruntime-extensions/onnxruntime_extensions/util.py

187 lines
7.0 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
util.py: Miscellaneous utility functions
"""
import onnx
import pathlib
import inspect
import numpy as np
# Some utility functions for tests and tools
def get_test_data_file(*sub_dirs):
    """
    Return the absolute path of a file addressed relative to the caller's source file.

    The directory of the Python file that called this function is joined with
    ``sub_dirs`` and the result is resolved. The file does not have to exist.

    Parameters:
        *sub_dirs: path components appended to the caller's directory.
    Returns:
        str: the resolved absolute path.
    """
    # One frame up is whoever called us directly; do not wrap this call in a helper,
    # otherwise the "caller" would become the helper's file.
    caller_code = inspect.currentframe().f_back.f_code
    caller_dir = pathlib.Path(caller_code.co_filename).parent
    target = caller_dir.joinpath(*sub_dirs)
    return str(target.resolve())
def read_file(path, mode='r'):
    """
    Return the entire content of the file at ``path``.

    Parameters:
        path: anything whose ``str()`` is a file path (e.g. ``str`` or ``pathlib.Path``).
        mode: mode passed to ``open``; ``'r'`` (default) yields ``str`` decoded with the
            platform default encoding, ``'rb'`` yields ``bytes``.
    """
    file_name = str(path)
    with open(file_name, mode) as stream:
        data = stream.read()
    return data
def mel_filterbank(
        n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
    """
    Build a Slaney-style mel filterbank with Slaney (area) normalization.

    Each row is one triangular mel filter and each column one rFFT frequency bin,
    so the result has shape ``(n_mels, n_fft // 2 + 1)``.

    Parameters:
        n_fft: FFT size; determines the number of frequency bins (columns).
        n_mels: number of mel filters (rows).
        sr: sample rate in Hz, used to place the rFFT bin centres.
        min_mel, max_mel: mel range covered by the filter edges. With the mel->Hz
            conversion used below, the default ``max_mel`` maps to 8000 Hz.
        dtype: dtype of the returned array.
    Returns:
        numpy.ndarray of shape ``(n_mels, n_fft // 2 + 1)`` and dtype ``dtype``.
    """
    # Centre frequency (Hz) of every rFFT bin.
    bin_hz = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

    # n_mels + 2 equally spaced mel points: each filter needs a left edge, a peak and a right edge.
    mel_points = np.linspace(min_mel, max_mel, n_mels + 2)

    # Mel -> Hz (Slaney): linear below 1 kHz, logarithmic above it.
    f_min, f_sp = 0.0, 200.0 / 3
    hz_points = f_min + f_sp * mel_points
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = np.log(6.4) / 27.0
    in_log_region = mel_points >= min_log_mel
    hz_points[in_log_region] = min_log_hz * np.exp(logstep * (mel_points[in_log_region] - min_log_mel))

    # Width between adjacent edges and signed distance from every edge to every bin.
    widths = np.diff(hz_points)
    offsets = hz_points[:, np.newaxis] - bin_hz[np.newaxis, :]

    # Rising and falling slopes of all triangles at once; a filter is their minimum, clipped at 0.
    rising = -offsets[:-2] / widths[:-1, np.newaxis]
    falling = offsets[2:] / widths[1:, np.newaxis]
    fbank = np.maximum(0, np.minimum(rising, falling)).astype(dtype)

    # Slaney normalization: scale each triangle by 2 / (its total width in Hz).
    fbank *= (2.0 / (hz_points[2:n_mels + 2] - hz_points[:n_mels]))[:, np.newaxis]
    return fbank
def remove_unused_constants(subgraph):
    """
    Remove ``Constant`` nodes whose output is never used, in place, recursively.

    A Constant's output counts as used when it is:
      * an output of ``subgraph`` itself,
      * an input of any node in ``subgraph``, or
      * referenced anywhere inside a nested subgraph (GRAPH / GRAPHS attributes,
        e.g. If/Loop bodies), since nested graphs may read outer-scope values by name.
    Nested subgraphs are then cleaned the same way.

    Parameters:
        subgraph (onnx.GraphProto): graph to clean; modified in place.
    """
    used_names = _referenced_names(subgraph)
    unused = [node for node in subgraph.node
              if node.op_type == 'Constant' and node.output[0] not in used_names]
    for node in unused:
        subgraph.node.remove(node)

    # Recurse into the remaining nodes' subgraphs (removed Constants carry none).
    for node in subgraph.node:
        for body in _nested_graphs(node):
            remove_unused_constants(body)


def _nested_graphs(node):
    """Yield every graph stored in a GRAPH or GRAPHS attribute of ``node``."""
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            yield attr.g
        elif attr.type == onnx.AttributeProto.GRAPHS:
            yield from attr.graphs


def _referenced_names(graph):
    """
    Return every value name ``graph`` exposes or reads: its outputs, the inputs of
    its nodes, and recursively the same for all subgraphs nested inside it.
    """
    names = {out.name for out in graph.output}
    for node in graph.node:
        names.update(node.input)
        for body in _nested_graphs(node):
            names.update(_referenced_names(body))
    return names
def remove_unused_initializers(subgraph, top_level_initializers=None):
    """
    Remove unused Constant nodes and unused initializers from an ONNX graph, in place.

    First calls ``remove_unused_constants(subgraph)``. Then, in ``subgraph`` and in
    every nested subgraph (GRAPH / GRAPHS attributes), each initializer is dropped
    unless its name is an output of that graph or is read by a node of that graph
    or of any subgraph nested inside it. Nested graphs may read outer-scope values
    by name, so an initializer used only inside an If/Loop body is kept.

    Graph inputs are not touched: an unused initializer that is also listed as a
    graph input is removed, but the graph input entry stays.

    Parameters:
        subgraph (onnx.GraphProto): graph to clean; modified in place.
        top_level_initializers: ignored; accepted only so existing callers keep working.
            Outer-scope usage is now detected by scanning nested subgraphs instead.
    """
    # remove_unused_constants already recurses into nested subgraphs, so one call suffices.
    remove_unused_constants(subgraph)

    def nested_graphs(node):
        # Every graph held by a GRAPH or GRAPHS attribute of node.
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                yield attr.g
            elif attr.type == onnx.AttributeProto.GRAPHS:
                yield from attr.graphs

    def referenced_names(graph):
        # Outputs and node inputs of graph and of every graph nested below it.
        names = {out.name for out in graph.output}
        for node in graph.node:
            names.update(node.input)
            for body in nested_graphs(node):
                names |= referenced_names(body)
        return names

    def prune(graph):
        used = referenced_names(graph)
        kept = [init for init in graph.initializer if init.name in used]
        if len(kept) != len(graph.initializer):
            # Rebuild the repeated field only when something was actually removed,
            # avoiding a needless copy of (potentially large) tensors.
            del graph.initializer[:]
            graph.initializer.extend(kept)
        for node in graph.node:
            for body in nested_graphs(node):
                prune(body)

    prune(subgraph)
def quick_merge(*models, connection_indices=None):
    """
    This function merges multiple ONNX models into a single model, without performing any ONNX format checks.

    Models are chained in the given order: outputs of model ``i - 1`` feed inputs of
    model ``i`` via ``onnx.compose.merge_graphs``. The input models are not modified.

    Parameters:
        *models (onnx.ModelProto): Varargs parameter representing the ONNX models to be merged.
        connection_indices (List[List[Tuple[int, int]]], optional): ``connection_indices[i - 1]``
            lists ``(output_index, input_index)`` pairs connecting an output of model ``i - 1``
            to an input of model ``i``. If not provided, the outputs of one model are paired
            positionally with the inputs of the next model.
    Returns:
        merged_model (onnx.ModelProto): The merged ONNX model.
    Raises:
        ValueError: If no model is given.
            If two models declare different opset versions for the same domain
            (the default domain may be spelled "" or "ai.onnx").
            Errors raised by ``onnx.compose.merge_graphs`` (e.g. on conflicting tensor
            names) are propagated unchanged.
    """
    if not models:
        raise ValueError("quick_merge requires at least one model.")

    def domain_key(opset):
        # "" and "ai.onnx" both denote the default domain; normalize without mutating the proto.
        return opset.domain or "ai.onnx"

    merged_graph = models[0].graph
    # Normalized domain -> OperatorSetIdProto seen so far.
    opset_imports = {domain_key(opset): opset for opset in models[0].opset_import}

    for model_idx, model in enumerate(models[1:], start=1):
        prev_outputs = models[model_idx - 1].graph.output
        if connection_indices is None:
            io_map = [(out.name, in_.name) for out, in_ in zip(prev_outputs, model.graph.input)]
        else:
            io_map = [(prev_outputs[out_idx].name, model.graph.input[in_idx].name)
                      for out_idx, in_idx in connection_indices[model_idx - 1]]
        merged_graph = onnx.compose.merge_graphs(merged_graph, model.graph, io_map)

        for opset in model.opset_import:
            domain = domain_key(opset)
            known = opset_imports.get(domain)
            if known is not None and known.version != opset.version:
                raise ValueError(f"Conflict in opset versions for domain '{domain}': " +
                                 f"model {model_idx} has version {opset.version}, while previous model has version " +
                                 f"{known.version}.")
            opset_imports[domain] = opset

    # The default-domain opset goes first so make_model_gen_version can derive the IR
    # version from it; custom domains are appended afterwards. If no model declares the
    # default domain, pass an empty list instead of [None].
    default_opset = opset_imports.pop("ai.onnx", None)
    merged_model = onnx.helper.make_model_gen_version(
        merged_graph,
        opset_imports=[default_opset] if default_opset is not None else [],
        producer_name='ONNX Model Merger')
    merged_model.opset_import.extend(opset_imports.values())
    return merged_model