356 строки
14 KiB
Python
356 строки
14 KiB
Python
import copy
|
|
import onnx
|
|
import torch
|
|
import warnings
|
|
import numpy as np
|
|
from onnx import helper, mapping
|
|
from collections import namedtuple
|
|
from .._ortapi2 import OrtPyFunction
|
|
from ._builder import is_path as _is_path
|
|
from ._onnx_ops import ONNXElementContainer, make_model_ex
|
|
from ._tensor import tensor_from_onnx, tensor_from_torch, tensor_set_session
|
|
|
|
|
|
def _is_numpy_object(x):
|
|
return isinstance(x, (np.ndarray, np.generic))
|
|
|
|
|
|
def _is_numpy_string_type(arr):
|
|
return arr.dtype.kind in {'U', 'S'}
|
|
|
|
|
|
def _is_string_type(x):
|
|
if not _is_numpy_object(x):
|
|
x = np.array(x)
|
|
return _is_numpy_string_type(x)
|
|
|
|
|
|
class ONNXModelUtils:
    """Helpers for assembling ONNX models out of traced nodes: renaming
    sub-graphs, inlining nested model-nodes, topologically sorting a node
    list, and building a ModelProto from a subset of nodes.
    """

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` of every proto in *iterables* with
        ``"<prefix_name>_"``.

        The protos are deep-copied first unless *inplace* is True; the
        (copied or original) sequence is returned.
        """
        new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
        for iz_ in new_iz:
            iz_.name = "{}_{}".format(prefix_name, iz_.name)
        return new_iz

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Merge *graph* into *graph_or_container* under the name *prefix*.

        The graph's initializers and value_info entries are appended to
        *graph_or_container* with prefixed names; the prefixed copy of the
        node list is returned to the caller (it is NOT added to the
        container here).
        """
        def io_rename(node, prefix_name, idx):
            # Deep-copy the node; unnamed nodes get a generated
            # "<prefix>_op<idx>" name, while already-named nodes keep
            # their original name from the deep copy.
            new_node = copy.deepcopy(node)
            if not node.name:
                new_node.name = "{}_op{}".format(prefix_name, idx)

            # Prefix every non-empty edge name; '' marks an absent optional
            # input/output slot and must stay empty.
            del new_node.input[:]
            new_node.input.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
            del new_node.output[:]
            new_node.output.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
            return new_node

        assert prefix is not None, 'The graph prefix could not be None'
        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
        return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Rename the sub-graph held in a node's 'body' attribute (e.g. a
        Loop/Scan-style control-flow node) with *prefix*.

        Nodes without a 'body' attribute are returned unchanged; otherwise
        the node's attribute list is rebuilt in place and the same node is
        returned.
        """
        if all(attr.name != 'body' for attr in node.attribute):
            return node

        def _process_attr(attr, prefix_name):
            if attr.name == 'body':
                new_attr = copy.deepcopy(attr)
                # Clear the copied sub-graph's value_info/nodes so that
                # _rename_graph can repopulate them with prefixed entries.
                del new_attr.g.value_info[:]
                del new_attr.g.node[:]
                new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
                # The sub-graph's own inputs/outputs are renamed in place on the copy.
                cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
                cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
                return new_attr
            else:
                return attr

        attr_list = list(_process_attr(attr_, prefix) for attr_ in node.attribute)
        del node.attribute[:]
        node.attribute.extend(attr_list)
        return node

    @classmethod
    def unfold_model_node(cls, container: ONNXElementContainer):
        """Inline any nested model-nodes of *container* into plain ONNX nodes.

        A "model node" is a node object carrying a ``model`` attribute (a
        ModelProto); its graph is renamed with the node's name as prefix and
        spliced into the flat node list, and its opset imports are merged
        into the top-most container. Returns the flattened node list.
        """
        top_containter = container
        while top_containter.parent is not None:  # only one opset_import in the model.
            top_containter = top_containter.parent

        nodes = container.nodes
        model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
        onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]

        for node in model_nodes.values():
            renamed_nodes = cls._rename_graph(node.model.graph, node.name, container)
            # Control-flow bodies inside the inlined graph are renamed too.
            onnx_nodes.extend(cls._process_node_body(nd_, node.name) for nd_ in renamed_nodes)

            top_containter.node_domain_version_pair_sets.update([(opset_.domain, opset_.version) for opset_ in node.model.opset_import])
        return onnx_nodes

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Return *nodes* re-ordered so every producer precedes its consumers.

        A DFS-based topological sort keyed on (unique) node names; raises
        RuntimeError for a tensor with no producer or when the graph
        contains a cycle.
        """
        op_output_map = {}
        DynNode = namedtuple('DynNode', ['name', 'output'])
        # A synthetic 'placeholder' node stands in as the producer of all
        # graph inputs and initializers; Constant nodes are extra sources.
        input_nodes = [DynNode(name='placeholder',
                               output=[nm_.name for nm_ in inputs] +
                                      [it_.name for it_ in container.initializers])] +\
                      [nd_ for nd_ in nodes if nd_.op_type == 'Constant']

        # Map each tensor name to the node that produces it.
        for nd_ in nodes + input_nodes:
            for ky_ in nd_.output:
                op_output_map[ky_] = nd_

        # edges: producer node name -> list of consumer nodes.
        edges = {}
        for op in nodes:
            for x in op.input:
                if x == '':
                    continue  # absent optional input slot
                try:
                    predecessor = op_output_map[x]
                except KeyError:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None

                val = edges.get(predecessor.name, [])
                val.append(op)
                edges[predecessor.name] = val

        # Ensure every node that produces a graph output appears in the
        # edge map, even when it has no consumers.
        for y_ in outputs:
            op = op_output_map[y_.name].name
            if op not in edges:
                edges[op] = []

        visited = set()
        sorted_nodes = []
        unfinished_nodes = set()

        def recursive_helper(node):
            # Standard DFS with a grey set (unfinished_nodes) for cycle
            # detection; finished nodes are prepended so producers come first.
            if node.name in visited:
                return

            if node.name in unfinished_nodes:
                raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))

            unfinished_nodes.add(node.name)
            if node.name in edges:  # if the node's output is not in the Graph output.
                assert node.name != '', 'this topological-sort depends on the unique node name.'
                for successor in edges[node.name]:
                    recursive_helper(successor)

            unfinished_nodes.remove(node.name)
            visited.add(node.name)
            # The synthetic placeholder is not a real node and is skipped.
            if node is not input_nodes[0]:
                sorted_nodes.insert(0, node)

        for nd_ in input_nodes:
            recursive_helper(nd_)

        return sorted_nodes

    @staticmethod
    def value_info_from_numpy(name, value):
        """Build a ValueInfoProto named *name* from a numpy array *value*,
        mapping string dtypes to TensorProto.STRING and everything else via
        the onnx numpy->tensor dtype table."""
        dtype = onnx.onnx_pb.TensorProto.STRING if \
            _is_numpy_string_type(value) else mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
        return helper.make_tensor_value_info(name, dtype, shape=value.shape)

    @staticmethod
    def model_from_ops(container, ops, ts_from, ts_to):
        """Build a ModelProto from the node list *ops*.

        *ts_from*/*ts_to* are the expected graph input/output value_infos;
        they are cross-checked against the dangling inputs/outputs computed
        from the nodes themselves. Only the initializers actually referenced
        by *ops* are kept.
        """
        all_inputs = []
        all_outputs = []
        iz_needed = set()
        # NOTE(review): 'initializer' here but 'initializers' below —
        # confirm the container type exposes both attributes.
        iz_set = set(iz_.name for iz_ in container.initializer)
        for op in ops:
            iz_needed.update(it_ for it_ in op.input if it_ in iz_set)
            all_inputs.extend(it_ for it_ in op.input if (it_ != '') and it_ not in iz_set)
            all_outputs.extend(ot_ for ot_ in op.output)

        # Tensors that are both produced and consumed are internal edges;
        # the remainder must match the declared graph inputs/outputs.
        intersections = set(all_inputs).intersection(set(all_outputs))
        assert set(all_inputs).difference(intersections) == set(ts_.name for ts_ in ts_from), \
            "The input list is different from the calculated from the op nodes"
        assert set(all_outputs).difference(intersections) == set(ts_.name for ts_ in ts_to), \
            "The output list is different from the calculated from the op nodes"

        final_iz = [iz_ for iz_ in container.initializers if iz_.name in iz_needed]
        graph = helper.make_graph(ops, 'dyngraph', ts_from, ts_to, final_iz)
        oxml = make_model_ex(graph,
                             container.node_domain_version_pair_sets,
                             container.target_opset)
        return oxml
|
|
|
|
|
|
class ONNXTraceSession:
    """Records tensor computations into an ONNXElementContainer so the traced
    graph can be executed (``runops``) or exported as an ONNX model
    (``build_model`` / ``save_as_onnx``).

    Start a trace with ``trace_for_onnx`` and use the returned session in a
    ``with`` statement so the global tensor session is reset on exit.
    """

    # Stack of live tracing sessions; the innermost active one is last.
    activated_sessions = []

    def __init__(self, target_opset):
        self.container = ONNXElementContainer(target_opset)  # current (possibly nested) op container
        self.inputs = []   # traced graph input tensors
        self.outputs = []  # traced graph output tensors

    def __enter__(self):
        # trace_for_onnx must have pushed this session already.
        assert len(self.activated_sessions) > 0 and self.activated_sessions[-1] is self, "trace not started?"
        return self

    # need this exit to close the session
    def __exit__(self, exec_type, exec_value, exec_tb):
        # Detach the global tensor session and pop this session off the stack.
        tensor_set_session(None)
        assert self is self.activated_sessions.pop()

    @classmethod
    def trace_for_onnx(cls, *inputs, names=None, target_opset=11) -> 'ONNXTraceSession':
        """
        Starting the trace of all tensor computation for ONNX graph generation.
        :param inputs: the input tensors; each can be a torch.Tensor or a numpy ndarray
            (anything else is converted via torch.tensor / np.array).
        :param names: the input names in the ONNX graph
        :param target_opset: the ONNX model opset_version
        :return: a tracing session object; in most cases it should be used in a with statement.
        """
        self = ONNXTraceSession(target_opset)
        self.activated_sessions.append(self)
        tensor_set_session(self)

        # Normalize the inputs: string data stays numpy, everything else that
        # is not already an array/tensor is wrapped into a torch.Tensor.
        np_inputs = [np.array(x) if _is_string_type(x) else x for x in inputs]
        np_inputs = [
            x if isinstance(x, (np.ndarray, np.generic, torch.Tensor)) or _is_string_type(x)
            else torch.tensor(x) for x in np_inputs]
        itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
                    else tensor_from_onnx(i_, None, None) for i_ in np_inputs]
        if names is None:
            names = []
        if len(inputs) != len(names):
            # NOTE(review): this mutates the caller-provided *names* list and
            # also warns when names was omitted entirely — confirm intended.
            warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
            names.extend([''] * (len(inputs) - len(names)))
            for idx_ in range(len(inputs)):
                # Missing/empty entries fall back to "input<idx>".
                names[idx_] = names[idx_] if names[idx_] else "input{}".format(idx_)
        num = min(len(itensors), len(names))
        for idx_ in range(num):
            itensors[idx_].name = names[idx_]
        self.inputs = itensors
        return self

    def runops(self, ts_from, ts_to):
        """Execute the traced ops that compute *ts_to* from *ts_from*.

        Walks the node list backwards to find the minimal suffix of nodes
        producing all of *ts_to* from *ts_from* plus the initializers, builds
        a model from it and runs it via OrtPyFunction. If execution fails the
        model is dumped to 'mt_debmodel.onnx' for debugging.
        :return: a tuple of (list of result arrays, the ModelProto that was run).
        """
        nodes = self.container.nodes
        inset = set(ts_.name for ts_ in ts_from)
        inset.update(iz_.name for iz_ in self.container.initializer)
        outset = set(ts_.name for ts_ in ts_to)
        missing_ts_set = set()  # tensors consumed but not yet produced/known
        node_num = len(nodes) - 1
        # Scan backwards until every consumed tensor is accounted for.
        while node_num >= 0:
            node = nodes[node_num]
            for ot_ in node.output:
                if ot_ in missing_ts_set:
                    missing_ts_set.remove(ot_)
                elif ot_ in outset:
                    outset.remove(ot_)
            for it_ in node.input:
                if it_ not in inset:
                    missing_ts_set.add(it_)
            if len(missing_ts_set) == 0:
                break
            node_num -= 1

        assert len(outset) == 0, "Some output cannot be in the node list."
        assert len(missing_ts_set) == 0, "Some input cannot be in the node list."
        collected_nodes = nodes[node_num:]
        vi_input = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                    for ts_ in ts_from]
        vi_output = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                     for ts_ in ts_to]
        oxml = ONNXModelUtils.model_from_ops(self.container,
                                             collected_nodes,
                                             vi_input,
                                             vi_output)
        result = None
        try:
            oxfunc = OrtPyFunction.from_model(oxml)
            result = oxfunc(*[ts_.numpy() for ts_ in ts_from])
        finally:
            # result stays None when from_model/the call raised: keep the
            # failing model on disk for offline inspection.
            if result is None:
                onnx.save_model(oxml, 'mt_debmodel.onnx')

        return result if isinstance(result, (list, tuple)) else [result], oxml

    def get_inputs(self):
        """Return the traced input tensors of this session."""
        return self.inputs

    def stack_container(self):
        """Push a new child container (e.g. for a sub-graph scope) and make
        it current. Returns the new container."""
        assert self.container is not None, "Stacked container must be in another one."
        sub_container = ONNXElementContainer(self.container.target_opset, self.container)
        self.container = sub_container
        return self.container

    def pop_container(self):
        """Make the parent container current again; the root container
        cannot be popped. Returns the now-current container."""
        assert self.container.parent is not None, "Cannot pop the root container."
        self.container = self.container.parent
        return self.container

    @staticmethod
    def build_graph(container, ts_inputs, ts_outputs, graph_name=None):
        """Build a GraphProto from *container* with the given input/output
        tensors; nested model-nodes are inlined and nodes topologically sorted."""
        # some constant ops are created to simulate the tensors generated from the runtime in the loop,
        # so we need to remove the node here
        to_del = []
        input_names = {it_.name: None for it_ in ts_inputs}
        for idx_, nd_ in enumerate(container.nodes):
            if nd_.op_type == 'Constant' and list(nd_.output)[0] in input_names:
                to_del.append(idx_)

        # Delete back-to-front so earlier indices stay valid.
        for idx_ in to_del[::-1]:
            container.nodes.pop(idx_)

        graph_name = container.get_unique_operator_name('subg') if not graph_name else graph_name
        nodes = ONNXModelUtils.unfold_model_node(container)
        nodes = ONNXModelUtils.topological_sort(container, nodes, ts_inputs, ts_outputs)

        # Prefer an existing value_info entry for an input over rebuilding
        # one from the tensor's own type/shape.
        for vi_ in container.value_info:
            if vi_.name in input_names:
                input_names[vi_.name] = vi_

        inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
                  if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
        outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
                                                 so.get_shape()) for so in ts_outputs]

        graph = helper.make_graph(nodes, graph_name, inputs,
                                  outputs, container.initializers)
        return graph

    def build_model(self, model_name=None, doc_string=None) -> onnx.ModelProto:
        """Build an onnx.ModelProto from the session's traced graph.
        :param model_name: the ONNX graph name (defaults to 'tcm')
        :param doc_string: the model doc string (defaults to '')
        """
        model_name = 'tcm' if model_name is None else model_name
        doc_string = '' if doc_string is None else doc_string
        container = self.container
        graph = self.build_graph(container, self.inputs, self.outputs, model_name)
        onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
                                   container.target_opset, doc_string=doc_string)
        return onnx_model

    def save_as_onnx(self, file_like_or_path, outputs, model_name=None, doc_string=None):
        """
        Build the ONNX model from the traced computation graph.
        :param file_like_or_path: an io.BytesIO like object or a file path
        :param outputs: the output tensor(s) to be specified as the ONNX graph output;
            can be a list/tuple when there are multiple output tensors.
        :param model_name: The ONNX model internal name
        :param doc_string: The doc string for the model
        :return: An ONNX ModelProto object.
        """
        if len(self.outputs) == 0 and outputs is None:
            raise RuntimeError("No output of the graph specified.")

        if len(self.outputs) == 0:
            self.outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]

        m = self.build_model(model_name, doc_string)

        if file_like_or_path is not None:
            if _is_path(file_like_or_path):
                with open(file_like_or_path, 'wb') as f:
                    f.write(m.SerializeToString())
            else:
                # Assume a writable file-like object; flush so callers can
                # read the bytes back immediately.
                f = file_like_or_path
                f.write(m.SerializeToString())
                f.flush()

        return m
|