onnxruntime-extensions/onnxruntime_extensions/_ocos.py

187 строки
6.2 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
"""
_ocos.py: PythonOp implementation
"""
import sys
import copy
import onnx
from onnx import helper
from ._extensions_pydll import ( # noqa
PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
def get_library_path():
    """
    Locate the binary of the custom operator library.

    The path is read from the ``__file__`` attribute of the already-imported
    ``onnxruntime_extensions._extensions_pydll`` module in ``sys.modules``.

    :return: A string of this library path.
    """
    return sys.modules['onnxruntime_extensions._extensions_pydll'].__file__
class Opdef:
    """
    Python-side description of a custom operator implemented by a callable.

    Instances built through :meth:`create` are stored in the class-level
    ``_odlist`` registry under ``id(instance)`` and are paired with a
    ``PyCustomOpDef`` object that is passed to ``add_custom_op``.
    """

    # Registry of created operators: id(opdef) -> opdef.
    _odlist = {}

    def __init__(self, op_type, func):
        """
        :param op_type: The operator type name.
        :param func: The callable executed when the operator is called.
        """
        self.op_type = op_type
        self.body = func
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """
        Build a decorator that registers the decorated function with
        :meth:`create`, forwarding the given arguments.

        :raises RuntimeError: if the first positional argument is callable
            (for instance when the decorator is applied without being called).
        :return: A one-argument callable taking the function to register.
        """
        first_is_callable = len(args) > 0 and hasattr(args[0], '__call__')
        if first_is_callable:
            raise RuntimeError("Unexpected arguments {}.".format(args))

        def _register(f):
            return Opdef.create(f, *args, **kwargs)

        return _register

    @staticmethod
    def create(func, *args, **kwargs):
        """
        Create an ``Opdef`` for ``func`` and register its native definition.

        Recognised keyword arguments:

        * ``op_type``: operator name; falls back to ``func.__name__``.
        * ``inputs`` / ``outputs``: type lists; each defaults to
          ``[PyCustomOpDef.dt_float]``.
        * ``attrs``: a mapping of attribute name to type, or a list/tuple of
          names which are then all typed ``PyCustomOpDef.dt_string``;
          defaults to an empty mapping.

        :return: The new ``Opdef`` instance.
        """
        op_type = kwargs.get('op_type', None) or func.__name__
        opdef = Opdef(op_type, func)
        od_id = id(opdef)

        # Keep a Python reference so the object is not destroyed while
        # its id is also stored in the C++ container.
        Opdef._odlist[od_id] = opdef

        native = PyCustomOpDef()
        opdef._nativedef = native
        native.op_type = op_type
        native.obj_id = od_id

        in_types = kwargs.get('inputs', None)
        native.input_types = [PyCustomOpDef.dt_float] if in_types is None else in_types

        out_types = kwargs.get('outputs', None)
        native.output_types = [PyCustomOpDef.dt_float] if out_types is None else out_types

        attr_spec = kwargs.get('attrs', None)
        if attr_spec is None:
            attr_spec = {}
        elif isinstance(attr_spec, (list, tuple)):
            # Bare names only: every attribute is treated as a string.
            attr_spec = {attr_name: PyCustomOpDef.dt_string for attr_name in attr_spec}
        native.attrs = attr_spec

        add_custom_op(native)
        return opdef

    def __call__(self, *args, **kwargs):
        """Call the wrapped function directly with the given arguments."""
        return self.body(*args, **kwargs)

    def cast_attributes(self, attributes):
        """
        Convert attribute values to the Python type matching their
        declared type in the native definition.

        :param attributes: Mapping of attribute name to raw value.
        :raises RuntimeError: if a declared type is not int64, float or string.
        :return: A new dict with the converted values.
        """
        converted = {}
        for key, raw in attributes.items():
            declared = self._nativedef.attrs[key]
            if declared == PyCustomOpDef.dt_int64:
                converted[key] = int(raw)
            elif declared == PyCustomOpDef.dt_float:
                converted[key] = float(raw)
            elif declared == PyCustomOpDef.dt_string:
                converted[key] = raw
            else:
                raise RuntimeError("Unsupported attribute type {}.".format(
                    declared))
        return converted
def _on_pyop_invocation(k_id, feed, attributes):
    """
    Run the registered Python operator identified by ``k_id``.

    The operator body is called with ``feed`` unpacked as positional
    arguments and the cast attributes as keyword arguments. Each output is
    expected to expose ``.shape`` and ``.flatten().tolist()`` (presumably a
    numpy array).

    :param k_id: Key of the operator in ``Opdef._odlist``.
    :param feed: Iterable of positional inputs for the operator body.
    :param attributes: Mapping of attribute name to raw value.
    :raises RuntimeError: if no operator is registered under ``k_id``.
    :return: A tuple ``(k_id, shape_0, flat_values_0, shape_1, ...)``.
    """
    op_ = Opdef._odlist.get(k_id)
    if op_ is None:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id))

    result = op_.body(*feed, **op_.cast_attributes(attributes))
    # A tuple means multiple outputs; anything else is a single output.
    outputs = result if isinstance(result, tuple) else (result,)

    packed = [k_id]
    for out in outputs:
        packed.append(out.shape)
        packed.append(out.flatten().tolist())
    return tuple(packed)
def _ensure_opset_domain(model):
    """
    Make sure the model imports the default custom operator domain.

    If no entry of ``model.opset_import`` has the domain returned by
    ``default_opset_domain()``, one is appended with version 1.

    :param model: The ONNX model; its ``opset_import`` may be extended in place.
    :return: The same model object.
    """
    domain = default_opset_domain()
    already_imported = any(
        entry.domain == domain for entry in model.opset_import)
    if not already_imported:
        model.opset_import.extend([helper.make_operatorsetid(domain, 1)])
    return model
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """
    Replace an existing input of a model with the new inputs, plus some extra nodes.

    The returned model is built on a deep copy of ``model`` whose graph is
    replaced by a newly made graph. The new graph keeps the original graph
    name, outputs and initializers. The extra nodes are appended after the
    existing nodes and the new inputs after the remaining inputs.

    :param model: The ONNX model loaded as ModelProto
    :param target_input: The input name to be replaced
    :param extra_nodes: The extra nodes to be added (any iterable)
    :param new_inputs: The new input (type: ValueInfoProto) sequence (any iterable)
    :return: The ONNX model after modification
    """
    graph = model.graph
    kept_inputs = [inp for inp in graph.input if inp.name != target_input]
    # list(...) lets callers pass tuples or other iterables, not only lists;
    # plain `list + tuple` would raise TypeError.
    all_inputs = kept_inputs + list(new_inputs)
    all_nodes = list(graph.node) + list(extra_nodes)
    new_graph = helper.make_graph(
        all_nodes, graph.name, all_inputs, list(graph.output), list(graph.initializer))
    new_model = copy.deepcopy(model)
    new_model.graph.CopyFrom(new_graph)
    return _ensure_opset_domain(new_model)
def hook_model_op(model, node_name, hook_func, input_types):
    """
    Insert a hook node in front of a named node of an ONNX model, which
    could be used for the model diagnosis.

    The hook node consumes the original inputs of the target node and
    produces outputs named ``<input>_hkd``, which the target node is rewired
    to read. The model passed in is modified in place.

    :param model: The ONNX model loaded as ModelProto
    :param node_name: The node name where the hook will be installed
    :param hook_func: The hook function, callback on the model inference
    :param input_types: The input types as a list; also used as the
        output types of the hook operator
    :raises ValueError: if no node in the graph is named ``node_name``
    :return: The ONNX model with the hook installed
    """
    # Shape inference is deliberately not run here; the model is used as-is.
    hooked_model = model
    graph_nodes = list(hooked_model.graph.node)
    hook_node_name = node_name + '_hkd'
    hook_op_type = "op_{}_{}".format(hook_func.__name__, node_name)

    # Position of the first node carrying the requested name, if any.
    target_idx = next(
        (idx for idx, nd in enumerate(graph_nodes) if nd.name == node_name),
        None)
    if target_idx is None:
        raise ValueError("{} is not an operator node name".format(node_name))

    target = graph_nodes[target_idx]
    hooked_names = [name + '_hkd' for name in list(target.input)]
    hook_node = onnx.helper.make_node(
        hook_op_type, target.input, hooked_names,
        name=hook_node_name, domain=default_opset_domain())

    # Rewire the target node to consume the hook outputs instead.
    del target.input[:]
    target.input.extend(hooked_names)

    # Rebuild the node list with the hook node placed just before the target.
    rebuilt = (graph_nodes[:target_idx]
               + [hook_node, target]
               + graph_nodes[target_idx + 1:])
    del hooked_model.graph.node[:]
    hooked_model.graph.node.extend(rebuilt)

    Opdef.create(hook_func, op_type=hook_op_type, inputs=input_types, outputs=input_types)
    return _ensure_opset_domain(hooked_model)
# Pass the Python-side dispatcher to the native module when this module is
# imported. Presumably the native side calls it back whenever a Python
# custom operator has to run -- confirm against the C++ binding.
PyCustomOpDef.install_hooker(_on_pyop_invocation)