onnxruntime-extensions/onnxruntime_extensions/onnxprocess/_session.py

356 строки
14 KiB
Python

import copy
import onnx
import torch
import warnings
import numpy as np
from onnx import helper, mapping
from collections import namedtuple
from .._ortapi2 import OrtPyFunction
from ._builder import is_path as _is_path
from ._onnx_ops import ONNXElementContainer, make_model_ex
from ._tensor import tensor_from_onnx, tensor_from_torch, tensor_set_session
def _is_numpy_object(x):
return isinstance(x, (np.ndarray, np.generic))
def _is_numpy_string_type(arr):
return arr.dtype.kind in {'U', 'S'}
def _is_string_type(x):
if not _is_numpy_object(x):
x = np.array(x)
return _is_numpy_string_type(x)
class ONNXModelUtils:
    """Helpers for renaming, flattening, ordering and packaging traced ONNX nodes."""

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` of every item with ``prefix_name`` and an underscore.

        :param iterables: objects exposing a writable ``name`` attribute.
        :param prefix_name: the prefix put in front of each name.
        :param inplace: when True the given objects are renamed themselves;
            otherwise deep copies are renamed and the originals stay untouched.
        :return: ``iterables`` itself when ``inplace`` is True, else a list of copies.
        """
        renamed = iterables if inplace else [copy.deepcopy(item) for item in iterables]
        for item in renamed:
            item.name = "{}_{}".format(prefix_name, item.name)
        return renamed

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Copy the nodes of ``graph`` with every tensor name prefixed.

        Renamed copies of the graph's initializers and value_info entries are
        appended to ``graph_or_container``.

        :param graph: the source graph; it is only read.
        :param prefix: the name prefix, must not be None.
        :param graph_or_container: the object whose ``initializer`` and
            ``value_info`` collections receive the renamed copies.
        :return: a list with one renamed deep copy per node of ``graph``.
        """
        assert prefix is not None, 'The graph prefix could not be None'

        def prefixed(tensor_name):
            # An empty name marks an omitted optional input/output; keep it empty.
            return "{}_{}".format(prefix, tensor_name) if tensor_name else ''

        def io_rename(node, idx):
            new_node = copy.deepcopy(node)
            if not node.name:
                # Only unnamed nodes get a generated name; named nodes keep theirs.
                new_node.name = "{}_op{}".format(prefix, idx)
            del new_node.input[:]
            new_node.input.extend(prefixed(name) for name in node.input)
            del new_node.output[:]
            new_node.output.extend(prefixed(name) for name in node.output)
            return new_node

        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
        return [io_rename(node, idx) for idx, node in enumerate(graph.node)]

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Apply the prefix renaming to the graph stored in a node's ``body`` attribute.

        :param node: the node to process; its attribute list is rewritten in place.
        :param prefix: the name prefix applied inside the body graph.
        :return: the same ``node`` object.
        """
        if all(attr.name != 'body' for attr in node.attribute):
            return node

        def renamed_body(attr):
            new_attr = copy.deepcopy(attr)
            del new_attr.g.value_info[:]
            del new_attr.g.node[:]
            # NOTE(review): the copied initializers are not cleared here, so the
            # new body ends up with both the original and the prefixed entries.
            new_attr.g.node.extend(cls._rename_graph(attr.g, prefix, new_attr.g))
            cls._rename_iter(new_attr.g.input, prefix, inplace=True)
            cls._rename_iter(new_attr.g.output, prefix, inplace=True)
            return new_attr

        attr_list = [renamed_body(attr) if attr.name == 'body' else attr
                     for attr in node.attribute]
        del node.attribute[:]
        node.attribute.extend(attr_list)
        return node

    @classmethod
    def unfold_model_node(cls, container: 'ONNXElementContainer'):
        """Replace every node carrying a ``model`` attribute by that model's nodes.

        The inlined nodes are renamed with the model node's name as prefix and
        the model's opset imports are recorded on the outermost container.

        :param container: the container whose ``nodes`` are unfolded.
        :return: a new list holding the plain nodes followed by the inlined ones.
        """
        top_container = container
        while top_container.parent is not None:  # only one opset_import in the model.
            top_container = top_container.parent

        nodes = container.nodes
        model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
        onnx_nodes = [node for node in nodes if node.name not in model_nodes]

        for model_node in model_nodes.values():
            renamed_nodes = cls._rename_graph(model_node.model.graph, model_node.name, container)
            onnx_nodes.extend(cls._process_node_body(node, model_node.name)
                              for node in renamed_nodes)
            top_container.node_domain_version_pair_sets.update(
                (opset.domain, opset.version) for opset in model_node.model.opset_import)
        return onnx_nodes

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Order ``nodes`` so that every node comes after the producers of its inputs.

        The traversal starts from a synthetic placeholder that "produces" the
        graph inputs and the container initializers, and from every Constant
        node. Nodes not reachable from those roots are left out of the result.
        The depth-first walk is iterative, so long node chains do not hit the
        interpreter recursion limit.

        :param container: object whose ``initializers`` have a ``name``.
        :param nodes: list of nodes with ``name``, ``op_type``, ``input``, ``output``.
        :param inputs: graph inputs, each with a ``name``.
        :param outputs: graph outputs, each with a ``name``.
        :return: the reachable nodes in topological order.
        :raises RuntimeError: when a node input has no producer or a cycle exists.
        """
        DynNode = namedtuple('DynNode', ['name', 'output'])
        placeholder = DynNode(name='placeholder',
                              output=[in_.name for in_ in inputs] +
                                     [iz.name for iz in container.initializers])
        input_nodes = [placeholder] + [node for node in nodes if node.op_type == 'Constant']

        op_output_map = {}
        for node in nodes + input_nodes:
            for tensor_name in node.output:
                op_output_map[tensor_name] = node

        # edges maps a producer's name to the list of nodes consuming its outputs.
        edges = {}
        for op in nodes:
            for tensor_name in op.input:
                if tensor_name == '':
                    continue
                try:
                    predecessor = op_output_map[tensor_name]
                except KeyError:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(
                            op.name, tensor_name)) from None
                edges.setdefault(predecessor.name, []).append(op)

        for out_ in outputs:
            edges.setdefault(op_output_map[out_.name].name, [])

        def successors_of(node):
            if node.name in edges:  # if the node's output is not in the Graph output.
                assert node.name != '', 'this topological-sort depends on the unique node name.'
                return iter(edges[node.name])
            return iter(())

        visited = set()
        unfinished_nodes = set()
        post_order = []  # nodes in completion order; reversed once at the end
        for root in input_nodes:
            if root.name in visited:
                continue
            unfinished_nodes.add(root.name)
            stack = [(root, successors_of(root))]
            while stack:
                node, pending = stack[-1]
                descended = False
                for successor in pending:
                    if successor.name in visited:
                        continue
                    if successor.name in unfinished_nodes:
                        raise RuntimeError(
                            "ONNX Graph is not a DAG, the cycle is found at {}".format(
                                successor.name))
                    unfinished_nodes.add(successor.name)
                    stack.append((successor, successors_of(successor)))
                    descended = True
                    break
                if descended:
                    continue
                # Every successor is done, so this node is complete.
                stack.pop()
                unfinished_nodes.remove(node.name)
                visited.add(node.name)
                if node is not placeholder:
                    post_order.append(node)

        post_order.reverse()
        return post_order

    @staticmethod
    def value_info_from_numpy(name, value):
        """Build a tensor value_info from a numpy value's dtype and shape.

        :param name: the tensor name.
        :param value: a numpy object; string dtypes map to ``TensorProto.STRING``.
        :return: the result of ``helper.make_tensor_value_info``.
        """
        if _is_numpy_string_type(value):
            dtype = onnx.onnx_pb.TensorProto.STRING
        else:
            dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
        return helper.make_tensor_value_info(name, dtype, shape=value.shape)

    @staticmethod
    def model_from_ops(container, ops, ts_from, ts_to):
        """Wrap ``ops`` into a model with the given inputs and outputs.

        Only the container initializers actually referenced by ``ops`` are
        copied into the graph.

        :param container: provides ``initializers``,
            ``node_domain_version_pair_sets`` and ``target_opset``.
        :param ops: the nodes of the graph.
        :param ts_from: value_info objects for the graph inputs.
        :param ts_to: value_info objects for the graph outputs.
        :return: the model returned by ``make_model_ex``.
        """
        initializer_names = {iz.name for iz in container.initializers}
        needed_initializers = set()
        consumed = set()
        produced = set()
        for op in ops:
            for tensor_name in op.input:
                if tensor_name in initializer_names:
                    needed_initializers.add(tensor_name)
                elif tensor_name != '':
                    consumed.add(tensor_name)
            produced.update(op.output)

        # Tensors both produced and consumed inside ``ops`` are internal edges.
        internal = consumed & produced
        assert consumed - internal == set(ts_.name for ts_ in ts_from), \
            "The input list is different from the calculated from the op nodes"
        assert produced - internal == set(ts_.name for ts_ in ts_to), \
            "The output list is different from the calculated from the op nodes"

        final_iz = [iz for iz in container.initializers if iz.name in needed_initializers]
        graph = helper.make_graph(ops, 'dyngraph', ts_from, ts_to, final_iz)
        return make_model_ex(graph,
                             container.node_domain_version_pair_sets,
                             container.target_opset)
class ONNXTraceSession:
activated_sessions = []
def __init__(self, target_opset):
self.container = ONNXElementContainer(target_opset)
self.inputs = []
self.outputs = []
def __enter__(self):
assert len(self.activated_sessions) > 0 and self.activated_sessions[-1] is self, "trace not started?"
return self
# need this exit to close the session
def __exit__(self, exec_type, exec_value, exec_tb):
tensor_set_session(None)
assert self is self.activated_sessions.pop()
@classmethod
def trace_for_onnx(cls, *inputs, names=None, target_opset=11) -> 'ONNXTraceSession':
"""
Starting the trace all tensor computation for ONNX graph generation.
:param inputs: the input tensor, could a torch.Tensor or a numpy ndarray.
:param names: The input names the ONNX graph
:param target_opset: The ONNX model opset_version
:return: A tracing session object, in most case, it should be used in the with statement.
"""
self = ONNXTraceSession(target_opset)
self.activated_sessions.append(self)
tensor_set_session(self)
np_inputs = [np.array(x) if _is_string_type(x) else x for x in inputs]
np_inputs = [
x if isinstance(x, (np.ndarray, np.generic, torch.Tensor)) or _is_string_type(x)
else torch.tensor(x) for x in np_inputs]
itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
else tensor_from_onnx(i_, None, None) for i_ in np_inputs]
if names is None:
names = []
if len(inputs) != len(names):
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
names.extend([''] * (len(inputs) - len(names)))
for idx_ in range(len(inputs)):
names[idx_] = names[idx_] if names[idx_] else "input{}".format(idx_)
num = min(len(itensors), len(names))
for idx_ in range(num):
itensors[idx_].name = names[idx_]
self.inputs = itensors
return self
def runops(self, ts_from, ts_to):
nodes = self.container.nodes
inset = set(ts_.name for ts_ in ts_from)
inset.update(iz_.name for iz_ in self.container.initializer)
outset = set(ts_.name for ts_ in ts_to)
missing_ts_set = set()
node_num = len(nodes) - 1
while node_num >= 0:
node = nodes[node_num]
for ot_ in node.output:
if ot_ in missing_ts_set:
missing_ts_set.remove(ot_)
elif ot_ in outset:
outset.remove(ot_)
for it_ in node.input:
if it_ not in inset:
missing_ts_set.add(it_)
if len(missing_ts_set) == 0:
break
node_num -= 1
assert len(outset) == 0, "Some output cannot be in the node list."
assert len(missing_ts_set) == 0, "Some input cannot be in the node list."
collected_nodes = nodes[node_num:]
vi_input = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
for ts_ in ts_from]
vi_output = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
for ts_ in ts_to]
oxml = ONNXModelUtils.model_from_ops(self.container,
collected_nodes,
vi_input,
vi_output)
result = None
try:
oxfunc = OrtPyFunction.from_model(oxml)
result = oxfunc(*[ts_.numpy() for ts_ in ts_from])
finally:
if result is None:
onnx.save_model(oxml, 'mt_debmodel.onnx')
return result if isinstance(result, (list, tuple)) else [result], oxml
def get_inputs(self):
return self.inputs
def stack_container(self):
assert self.container is not None, "Stacked container must be in another one."
sub_container = ONNXElementContainer(self.container.target_opset, self.container)
self.container = sub_container
return self.container
def pop_container(self):
assert self.container.parent is not None, "Cannot pop the root container."
self.container = self.container.parent
return self.container
@staticmethod
def build_graph(container, ts_inputs, ts_outputs, graph_name=None):
# some constant ops are created to simulate the tensors generated from the runtime in the loop,
# so we need to remove the node here
to_del = []
input_names = {it_.name: None for it_ in ts_inputs}
for idx_, nd_ in enumerate(container.nodes):
if nd_.op_type == 'Constant' and list(nd_.output)[0] in input_names:
to_del.append(idx_)
for idx_ in to_del[::-1]:
container.nodes.pop(idx_)
graph_name = container.get_unique_operator_name('subg') if not graph_name else graph_name
nodes = ONNXModelUtils.unfold_model_node(container)
nodes = ONNXModelUtils.topological_sort(container, nodes, ts_inputs, ts_outputs)
for vi_ in container.value_info:
if vi_.name in input_names:
input_names[vi_.name] = vi_
inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
so.get_shape()) for so in ts_outputs]
graph = helper.make_graph(nodes, graph_name, inputs,
outputs, container.initializers)
return graph
def build_model(self, model_name=None, doc_string=None) -> onnx.ModelProto:
model_name = 'tcm' if model_name is None else model_name
doc_string = '' if doc_string is None else doc_string
container = self.container
graph = self.build_graph(container, self.inputs, self.outputs, model_name)
onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
container.target_opset, doc_string=doc_string)
return onnx_model
def save_as_onnx(self, file_like_or_path, outputs, model_name=None, doc_string=None):
"""
Build the ONNX model from the traced computation graph.
:param file_like_or_path: an io.BytesIO like object or a file path
:param outputs: the output tensor to be specified as the ONNX graph output,
Could be a string if there are multiple output tensors.
:param model_name: The ONNX model internal name
:param doc_string: The doc string for the model
:return: A ONNX ModelProto object.
"""
if len(self.outputs) == 0 and outputs is None:
raise RuntimeError("No output of the graph specified.")
if len(self.outputs) == 0:
self.outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]
m = self.build_model(model_name, doc_string)
if file_like_or_path is not None:
if _is_path(file_like_or_path):
with open(file_like_or_path, 'wb') as f:
f.write(m.SerializeToString())
else:
f = file_like_or_path
f.write(m.SerializeToString())
f.flush()
return m