make onnx package to be optional. (#653)
* putting onnx package to be optional * update the ci.yml * add more message of missing ONNX package
This commit is contained in:
Родитель
fc275e623f
Коммит
b045e66396
|
@ -89,8 +89,7 @@ stages:
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install --upgrade setuptools
|
python -m pip install --upgrade setuptools
|
||||||
python -m pip install onnxruntime==$(ort.version)
|
python -m pip install onnxruntime==$(ort.version)
|
||||||
python -m pip install -r requirements.txt
|
displayName: Install requirements
|
||||||
displayName: Install requirements.txt
|
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
CPU_NUMBER=8 python -m pip install .
|
CPU_NUMBER=8 python -m pip install .
|
||||||
|
@ -283,8 +282,7 @@ stages:
|
||||||
python -m pip install --upgrade setuptools
|
python -m pip install --upgrade setuptools
|
||||||
python -m pip install --upgrade wheel
|
python -m pip install --upgrade wheel
|
||||||
python -m pip install onnxruntime==$(ort.version)
|
python -m pip install onnxruntime==$(ort.version)
|
||||||
python -m pip install -r requirements.txt
|
displayName: Install requirements
|
||||||
displayName: Install requirements.txt
|
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
python -c "import onnxruntime;print(onnxruntime.__version__)"
|
python -c "import onnxruntime;print(onnxruntime.__version__)"
|
||||||
|
@ -419,7 +417,6 @@ stages:
|
||||||
call activate pyenv
|
call activate pyenv
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install onnxruntime==$(ort.version)
|
python -m pip install onnxruntime==$(ort.version)
|
||||||
python -m pip install -r requirements.txt
|
|
||||||
python -m pip install -r requirements-dev.txt
|
python -m pip install -r requirements-dev.txt
|
||||||
displayName: Install requirements{-dev}.txt and cmake python modules
|
displayName: Install requirements{-dev}.txt and cmake python modules
|
||||||
|
|
||||||
|
@ -653,7 +650,6 @@ stages:
|
||||||
python3 -m pip install --upgrade pip; \
|
python3 -m pip install --upgrade pip; \
|
||||||
python3 -m pip install --upgrade setuptools; \
|
python3 -m pip install --upgrade setuptools; \
|
||||||
python3 -m pip install onnxruntime-gpu==$(ORT_VERSION); \
|
python3 -m pip install onnxruntime-gpu==$(ORT_VERSION); \
|
||||||
python3 -m pip install -r requirements.txt; \
|
|
||||||
python3 -m pip install -v --config-settings "ortx-user-option=use-cuda" . ; \
|
python3 -m pip install -v --config-settings "ortx-user-option=use-cuda" . ; \
|
||||||
python3 -m pip install $(TORCH_VERSION) ; \
|
python3 -m pip install $(TORCH_VERSION) ; \
|
||||||
python3 -m pip install -r requirements-dev.txt; \
|
python3 -m pip install -r requirements-dev.txt; \
|
||||||
|
|
|
@ -31,12 +31,15 @@ python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.gi
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
## 1. Generate the pre-/post- processing ONNX model
|
## 1. Generation of Pre-/Post-Processing ONNX Model
|
||||||
With onnxruntime-extensions Python package, you can easily get the ONNX processing graph by converting them from Huggingface transformer data processing classes, check the following API for details.
|
The `onnxruntime-extensions` Python package provides a convenient way to generate the ONNX processing graph. This can be achieved by converting the Huggingface transformer data processing classes into the desired format. For more detailed information, please refer to the API below:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
help(onnxruntime_extensions.gen_processing_models)
|
help(onnxruntime_extensions.gen_processing_models)
|
||||||
```
|
```
|
||||||
### NOTE: These data processing model can be merged into other model [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
|
### NOTE:
|
||||||
|
The generation of model processing requires the **ONNX** package to be installed. The data processing models generated in this manner can be merged with other models using the [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
|
||||||
|
|
||||||
## 2. Using Extensions for ONNX Runtime inference
|
## 2. Using Extensions for ONNX Runtime inference
|
||||||
|
|
||||||
### Python
|
### Python
|
||||||
|
|
|
@ -10,37 +10,70 @@ This enables more flexibility and control over model execution, thus expanding t
|
||||||
|
|
||||||
__author__ = "Microsoft"
|
__author__ = "Microsoft"
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'gen_processing_models',
|
|
||||||
'ort_inference',
|
|
||||||
'get_library_path',
|
|
||||||
'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
|
|
||||||
'enable_py_op',
|
|
||||||
'expand_onnx_inputs',
|
|
||||||
'hook_model_op',
|
|
||||||
'default_opset_domain',
|
|
||||||
'OrtPyFunction', 'PyOrtFunction',
|
|
||||||
'optimize_model',
|
|
||||||
'make_onnx_model',
|
|
||||||
'ONNXRuntimeError',
|
|
||||||
'hash_64',
|
|
||||||
'__version__',
|
|
||||||
]
|
|
||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
from ._ocos import get_library_path
|
from ._ocos import get_library_path
|
||||||
from ._ocos import Opdef, PyCustomOpDef
|
from ._ocos import Opdef, PyCustomOpDef
|
||||||
from ._ocos import hash_64
|
from ._ocos import hash_64
|
||||||
from ._ocos import enable_py_op
|
from ._ocos import enable_py_op
|
||||||
from ._ocos import expand_onnx_inputs
|
|
||||||
from ._ocos import hook_model_op
|
|
||||||
from ._ocos import default_opset_domain
|
from ._ocos import default_opset_domain
|
||||||
from ._cuops import * # noqa
|
|
||||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
|
|
||||||
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
|
_lib_only = False
|
||||||
from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException
|
|
||||||
from .cvt import gen_processing_models
|
try:
|
||||||
|
import onnx # noqa
|
||||||
|
import onnxruntime # noqa
|
||||||
|
except ImportError:
|
||||||
|
_lib_only = True
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_offline_api = [
|
||||||
|
"gen_processing_models",
|
||||||
|
"ort_inference",
|
||||||
|
"OrtPyFunction",
|
||||||
|
"PyOrtFunction",
|
||||||
|
"optimize_model",
|
||||||
|
"make_onnx_model",
|
||||||
|
"ONNXRuntimeError",
|
||||||
|
]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_library_path",
|
||||||
|
"Opdef",
|
||||||
|
"onnx_op",
|
||||||
|
"PyCustomOpDef",
|
||||||
|
"PyOp",
|
||||||
|
"enable_py_op",
|
||||||
|
"expand_onnx_inputs",
|
||||||
|
"hook_model_op",
|
||||||
|
"default_opset_domain",
|
||||||
|
"hash_64",
|
||||||
|
"__version__",
|
||||||
|
]
|
||||||
|
|
||||||
# rename the implementation with a more formal name
|
# rename the implementation with a more formal name
|
||||||
onnx_op = Opdef.declare
|
onnx_op = Opdef.declare
|
||||||
PyOp = PyCustomOpDef
|
PyOp = PyCustomOpDef
|
||||||
|
|
||||||
|
|
||||||
|
if _lib_only:
|
||||||
|
|
||||||
|
def _unimplemented(*args, **kwargs):
|
||||||
|
raise NotImplementedError("ONNX or ONNX Runtime is not installed")
|
||||||
|
|
||||||
|
gen_processing_models = _unimplemented
|
||||||
|
OrtPyFunction = _unimplemented
|
||||||
|
ort_inference = _unimplemented
|
||||||
|
|
||||||
|
else:
|
||||||
|
__all__ += _offline_api
|
||||||
|
|
||||||
|
from ._cuops import * # noqa
|
||||||
|
from ._ortapi2 import hook_model_op
|
||||||
|
from ._ortapi2 import expand_onnx_inputs
|
||||||
|
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
|
||||||
|
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
|
||||||
|
from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException # noqa
|
||||||
|
from .cvt import gen_processing_models
|
||||||
|
|
|
@ -7,34 +7,36 @@ _ocos.py: PythonOp implementation
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
|
||||||
import glob
|
import glob
|
||||||
import onnx
|
|
||||||
from onnx import helper
|
|
||||||
|
|
||||||
|
|
||||||
def _search_cuda_dir():
|
def _search_cuda_dir():
|
||||||
paths = os.getenv('PATH', '').split(os.pathsep)
|
paths = os.getenv("PATH", "").split(os.pathsep)
|
||||||
for path in paths:
|
for path in paths:
|
||||||
for filename in glob.glob(os.path.join(path, 'cudart64*.dll')):
|
for filename in glob.glob(os.path.join(path, "cudart64*.dll")):
|
||||||
return os.path.dirname(filename)
|
return os.path.dirname(filename)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
if sys.platform == 'win32':
|
if sys.platform == "win32":
|
||||||
from . import _version # noqa: E402
|
from . import _version # noqa: E402
|
||||||
if hasattr(_version, 'cuda'):
|
|
||||||
|
if hasattr(_version, "cuda"):
|
||||||
cuda_path = _search_cuda_dir()
|
cuda_path = _search_cuda_dir()
|
||||||
if cuda_path is None:
|
if cuda_path is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError("Cannot locate CUDA directory in the environment variable for GPU package")
|
||||||
"Cannot locate CUDA directory in the environment variable for GPU package")
|
|
||||||
|
|
||||||
os.add_dll_directory(cuda_path)
|
os.add_dll_directory(cuda_path)
|
||||||
|
|
||||||
|
|
||||||
from ._extensions_pydll import ( # noqa
|
from ._extensions_pydll import ( # noqa
|
||||||
PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
|
PyCustomOpDef,
|
||||||
|
enable_py_op,
|
||||||
|
add_custom_op,
|
||||||
|
hash_64,
|
||||||
|
default_opset_domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_library_path():
|
def get_library_path():
|
||||||
|
@ -42,12 +44,11 @@ def get_library_path():
|
||||||
The custom operator library binary path
|
The custom operator library binary path
|
||||||
:return: A string of this library path.
|
:return: A string of this library path.
|
||||||
"""
|
"""
|
||||||
mod = sys.modules['onnxruntime_extensions._extensions_pydll']
|
mod = sys.modules["onnxruntime_extensions._extensions_pydll"]
|
||||||
return mod.__file__
|
return mod.__file__
|
||||||
|
|
||||||
|
|
||||||
class Opdef:
|
class Opdef:
|
||||||
|
|
||||||
_odlist = {}
|
_odlist = {}
|
||||||
|
|
||||||
def __init__(self, op_type, func):
|
def __init__(self, op_type, func):
|
||||||
|
@ -57,14 +58,14 @@ class Opdef:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def declare(*args, **kwargs):
|
def declare(*args, **kwargs):
|
||||||
if len(args) > 0 and hasattr(args[0], '__call__'):
|
if len(args) > 0 and hasattr(args[0], "__call__"):
|
||||||
raise RuntimeError("Unexpected arguments {}.".format(args))
|
raise RuntimeError("Unexpected arguments {}.".format(args))
|
||||||
# return Opdef._create(args[0])
|
# return Opdef._create(args[0])
|
||||||
return lambda f: Opdef.create(f, *args, **kwargs)
|
return lambda f: Opdef.create(f, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(func, *args, **kwargs):
|
def create(func, *args, **kwargs):
|
||||||
name = kwargs.get('op_type', None)
|
name = kwargs.get("op_type", None)
|
||||||
op_type = name or func.__name__
|
op_type = name or func.__name__
|
||||||
opdef = Opdef(op_type, func)
|
opdef = Opdef(op_type, func)
|
||||||
od_id = id(opdef)
|
od_id = id(opdef)
|
||||||
|
@ -76,15 +77,15 @@ class Opdef:
|
||||||
opdef._nativedef.op_type = op_type
|
opdef._nativedef.op_type = op_type
|
||||||
opdef._nativedef.obj_id = od_id
|
opdef._nativedef.obj_id = od_id
|
||||||
|
|
||||||
inputs = kwargs.get('inputs', None)
|
inputs = kwargs.get("inputs", None)
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
inputs = [PyCustomOpDef.dt_float]
|
inputs = [PyCustomOpDef.dt_float]
|
||||||
opdef._nativedef.input_types = inputs
|
opdef._nativedef.input_types = inputs
|
||||||
outputs = kwargs.get('outputs', None)
|
outputs = kwargs.get("outputs", None)
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
outputs = [PyCustomOpDef.dt_float]
|
outputs = [PyCustomOpDef.dt_float]
|
||||||
opdef._nativedef.output_types = outputs
|
opdef._nativedef.output_types = outputs
|
||||||
attrs = kwargs.get('attrs', None)
|
attrs = kwargs.get("attrs", None)
|
||||||
if attrs is None:
|
if attrs is None:
|
||||||
attrs = {}
|
attrs = {}
|
||||||
elif isinstance(attrs, (list, tuple)):
|
elif isinstance(attrs, (list, tuple)):
|
||||||
|
@ -106,16 +107,15 @@ class Opdef:
|
||||||
elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
|
elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
|
||||||
res[k] = v
|
res[k] = v
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unsupported attribute type {}.".format(
|
raise RuntimeError("Unsupported attribute type {}.".format(self._nativedef.attrs[k]))
|
||||||
self._nativedef.attrs[k]))
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def _on_pyop_invocation(k_id, feed, attributes):
|
def _on_pyop_invocation(k_id, feed, attributes):
|
||||||
if k_id not in Opdef._odlist:
|
if k_id not in Opdef._odlist:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Unable to find function id={}. "
|
"Unable to find function id={}. " "Did you decorate the operator with @onnx_op?.".format(k_id)
|
||||||
"Did you decorate the operator with @onnx_op?.".format(k_id))
|
)
|
||||||
op_ = Opdef._odlist[k_id]
|
op_ = Opdef._odlist[k_id]
|
||||||
rv = op_.body(*feed, **op_.cast_attributes(attributes))
|
rv = op_.body(*feed, **op_.cast_attributes(attributes))
|
||||||
if isinstance(rv, tuple):
|
if isinstance(rv, tuple):
|
||||||
|
@ -127,86 +127,7 @@ def _on_pyop_invocation(k_id, feed, attributes):
|
||||||
res = tuple(res)
|
res = tuple(res)
|
||||||
else:
|
else:
|
||||||
res = (rv.shape, rv.flatten().tolist())
|
res = (rv.shape, rv.flatten().tolist())
|
||||||
return (k_id, ) + res
|
return (k_id,) + res
|
||||||
|
|
||||||
|
|
||||||
def _ensure_opset_domain(model):
|
|
||||||
op_domain_name = default_opset_domain()
|
|
||||||
domain_missing = True
|
|
||||||
for oi_ in model.opset_import:
|
|
||||||
if oi_.domain == op_domain_name:
|
|
||||||
domain_missing = False
|
|
||||||
|
|
||||||
if domain_missing:
|
|
||||||
model.opset_import.extend(
|
|
||||||
[helper.make_operatorsetid(op_domain_name, 1)])
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
|
|
||||||
"""
|
|
||||||
Replace the existing inputs of a model with the new inputs, plus some extra nodes
|
|
||||||
:param model: The ONNX model loaded as ModelProto
|
|
||||||
:param target_input: The input name to be replaced
|
|
||||||
:param extra_nodes: The extra nodes to be added
|
|
||||||
:param new_inputs: The new input (type: ValueInfoProto) sequence
|
|
||||||
:return: The ONNX model after modification
|
|
||||||
"""
|
|
||||||
graph = model.graph
|
|
||||||
new_inputs = [n for n in graph.input if n.name !=
|
|
||||||
target_input] + new_inputs
|
|
||||||
new_nodes = list(model.graph.node) + extra_nodes
|
|
||||||
new_graph = helper.make_graph(
|
|
||||||
new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
|
|
||||||
|
|
||||||
new_model = copy.deepcopy(model)
|
|
||||||
new_model.graph.CopyFrom(new_graph)
|
|
||||||
|
|
||||||
return _ensure_opset_domain(new_model)
|
|
||||||
|
|
||||||
|
|
||||||
def hook_model_op(model, node_name, hook_func, input_types):
|
|
||||||
"""
|
|
||||||
Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
|
|
||||||
:param model: The ONNX model loaded as ModelProto
|
|
||||||
:param node_name: The node name where the hook will be installed
|
|
||||||
:param hook_func: The hook function, callback on the model inference
|
|
||||||
:param input_types: The input types as a list
|
|
||||||
:return: The ONNX model with the hook installed
|
|
||||||
"""
|
|
||||||
|
|
||||||
# onnx.shape_inference is very unstable, useless.
|
|
||||||
# hkd_model = shape_inference.infer_shapes(model)
|
|
||||||
hkd_model = model
|
|
||||||
|
|
||||||
n_idx = 0
|
|
||||||
hnode, nnode = (None, None)
|
|
||||||
nodes = list(hkd_model.graph.node)
|
|
||||||
brkpt_name = node_name + '_hkd'
|
|
||||||
optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
|
|
||||||
for n_ in nodes:
|
|
||||||
if n_.name == node_name:
|
|
||||||
input_names = list(n_.input)
|
|
||||||
brk_output_name = [i_ + '_hkd' for i_ in input_names]
|
|
||||||
hnode = onnx.helper.make_node(
|
|
||||||
optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain())
|
|
||||||
nnode = n_
|
|
||||||
del nnode.input[:]
|
|
||||||
nnode.input.extend(brk_output_name)
|
|
||||||
break
|
|
||||||
n_idx += 1
|
|
||||||
|
|
||||||
if hnode is None:
|
|
||||||
raise ValueError("{} is not an operator node name".format(node_name))
|
|
||||||
|
|
||||||
repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx+1:]
|
|
||||||
del hkd_model.graph.node[:]
|
|
||||||
hkd_model.graph.node.extend(repacked)
|
|
||||||
|
|
||||||
Opdef.create(hook_func, op_type=optype_name,
|
|
||||||
inputs=input_types, outputs=input_types)
|
|
||||||
return _ensure_opset_domain(hkd_model)
|
|
||||||
|
|
||||||
|
|
||||||
PyCustomOpDef.install_hooker(_on_pyop_invocation)
|
PyCustomOpDef.install_hooker(_on_pyop_invocation)
|
||||||
|
|
|
@ -7,8 +7,9 @@
|
||||||
_ortapi2.py: ONNXRuntime-Extensions Python API
|
_ortapi2.py: ONNXRuntime-Extensions Python API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ._ocos import default_opset_domain, get_library_path # noqa
|
from ._ocos import default_opset_domain, get_library_path, Opdef
|
||||||
from ._cuops import onnx, onnx_proto, SingleOpGraph
|
from ._cuops import onnx, onnx_proto, SingleOpGraph
|
||||||
|
|
||||||
_ort_check_passed = False
|
_ort_check_passed = False
|
||||||
|
@ -25,6 +26,82 @@ if not _ort_check_passed:
|
||||||
raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
|
raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_opset_domain(model):
|
||||||
|
op_domain_name = default_opset_domain()
|
||||||
|
domain_missing = True
|
||||||
|
for oi_ in model.opset_import:
|
||||||
|
if oi_.domain == op_domain_name:
|
||||||
|
domain_missing = False
|
||||||
|
|
||||||
|
if domain_missing:
|
||||||
|
model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)])
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def hook_model_op(model, node_name, hook_func, input_types):
|
||||||
|
"""
|
||||||
|
Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
|
||||||
|
:param model: The ONNX model loaded as ModelProto
|
||||||
|
:param node_name: The node name where the hook will be installed
|
||||||
|
:param hook_func: The hook function, callback on the model inference
|
||||||
|
:param input_types: The input types as a list
|
||||||
|
:return: The ONNX model with the hook installed
|
||||||
|
"""
|
||||||
|
|
||||||
|
# onnx.shape_inference is very unstable, useless.
|
||||||
|
# hkd_model = shape_inference.infer_shapes(model)
|
||||||
|
hkd_model = model
|
||||||
|
|
||||||
|
n_idx = 0
|
||||||
|
hnode, nnode = (None, None)
|
||||||
|
nodes = list(hkd_model.graph.node)
|
||||||
|
brkpt_name = node_name + "_hkd"
|
||||||
|
optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
|
||||||
|
for n_ in nodes:
|
||||||
|
if n_.name == node_name:
|
||||||
|
input_names = list(n_.input)
|
||||||
|
brk_output_name = [i_ + "_hkd" for i_ in input_names]
|
||||||
|
hnode = onnx.helper.make_node(
|
||||||
|
optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()
|
||||||
|
)
|
||||||
|
nnode = n_
|
||||||
|
del nnode.input[:]
|
||||||
|
nnode.input.extend(brk_output_name)
|
||||||
|
break
|
||||||
|
n_idx += 1
|
||||||
|
|
||||||
|
if hnode is None:
|
||||||
|
raise ValueError("{} is not an operator node name".format(node_name))
|
||||||
|
|
||||||
|
repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :]
|
||||||
|
del hkd_model.graph.node[:]
|
||||||
|
hkd_model.graph.node.extend(repacked)
|
||||||
|
|
||||||
|
Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types)
|
||||||
|
return _ensure_opset_domain(hkd_model)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
|
||||||
|
"""
|
||||||
|
Replace the existing inputs of a model with the new inputs, plus some extra nodes
|
||||||
|
:param model: The ONNX model loaded as ModelProto
|
||||||
|
:param target_input: The input name to be replaced
|
||||||
|
:param extra_nodes: The extra nodes to be added
|
||||||
|
:param new_inputs: The new input (type: ValueInfoProto) sequence
|
||||||
|
:return: The ONNX model after modification
|
||||||
|
"""
|
||||||
|
graph = model.graph
|
||||||
|
new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs
|
||||||
|
new_nodes = list(model.graph.node) + extra_nodes
|
||||||
|
new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
|
||||||
|
|
||||||
|
new_model = copy.deepcopy(model)
|
||||||
|
new_model.graph.CopyFrom(new_graph)
|
||||||
|
|
||||||
|
return _ensure_opset_domain(new_model)
|
||||||
|
|
||||||
|
|
||||||
def get_opset_version_from_ort():
|
def get_opset_version_from_ort():
|
||||||
_ORT_OPSET_SUPPORT_TABLE = {
|
_ORT_OPSET_SUPPORT_TABLE = {
|
||||||
"1.5": 11,
|
"1.5": 11,
|
||||||
|
@ -37,10 +114,10 @@ def get_opset_version_from_ort():
|
||||||
"1.12": 17,
|
"1.12": 17,
|
||||||
"1.13": 17,
|
"1.13": 17,
|
||||||
"1.14": 18,
|
"1.14": 18,
|
||||||
"1.15": 18
|
"1.15": 18,
|
||||||
}
|
}
|
||||||
|
|
||||||
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
|
ort_ver_string = ".".join(_ort.__version__.split(".")[0:2])
|
||||||
max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
|
max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
|
||||||
if ort_ver_string > max_ver:
|
if ort_ver_string > max_ver:
|
||||||
ort_ver_string = max_ver
|
ort_ver_string = max_ver
|
||||||
|
@ -50,19 +127,18 @@ def get_opset_version_from_ort():
|
||||||
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
|
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
|
||||||
if opset_version == 0:
|
if opset_version == 0:
|
||||||
opset_version = get_opset_version_from_ort()
|
opset_version = get_opset_version_from_ort()
|
||||||
fn_mm = onnx.helper.make_model_gen_version if hasattr(onnx.helper, 'make_model_gen_version'
|
fn_mm = (
|
||||||
) else onnx.helper.make_model
|
onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model
|
||||||
model = fn_mm(graph, opset_imports=[
|
)
|
||||||
onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
|
model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)])
|
||||||
model.opset_import.extend(
|
model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
|
||||||
[onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class OrtPyFunction:
|
class OrtPyFunction:
|
||||||
"""
|
"""
|
||||||
OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
|
OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
|
||||||
equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
|
equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
|
||||||
standard Python function. The order of the function arguments correlates directly with
|
standard Python function. The order of the function arguments correlates directly with
|
||||||
the sequence of the input/output in the ONNX graph.
|
the sequence of the input/output in the ONNX graph.
|
||||||
"""
|
"""
|
||||||
|
@ -78,10 +154,10 @@ class OrtPyFunction:
|
||||||
self._onnx_model = None
|
self._onnx_model = None
|
||||||
self.ort_session = None
|
self.ort_session = None
|
||||||
self.default_inputs = {}
|
self.default_inputs = {}
|
||||||
self.execution_providers = ['CPUExecutionProvider']
|
self.execution_providers = ["CPUExecutionProvider"]
|
||||||
if not cpu_only:
|
if not cpu_only:
|
||||||
if _ort.get_device() == 'GPU':
|
if _ort.get_device() == "GPU":
|
||||||
self.execution_providers = ['CUDAExecutionProvider']
|
self.execution_providers = ["CUDAExecutionProvider"]
|
||||||
self.extra_session_options = {}
|
self.extra_session_options = {}
|
||||||
mpath = None
|
mpath = None
|
||||||
if isinstance(path_or_model, str):
|
if isinstance(path_or_model, str):
|
||||||
|
@ -99,8 +175,8 @@ class OrtPyFunction:
|
||||||
|
|
||||||
def add_default_input(self, **kwargs):
|
def add_default_input(self, **kwargs):
|
||||||
inputs = {
|
inputs = {
|
||||||
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
|
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8)
|
||||||
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
|
for ky_, val_ in kwargs.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
self.default_inputs.update(inputs)
|
self.default_inputs.update(inputs)
|
||||||
|
@ -124,30 +200,29 @@ class OrtPyFunction:
|
||||||
self._oxml = oxml
|
self._oxml = oxml
|
||||||
if model_path is not None:
|
if model_path is not None:
|
||||||
self.ort_session = _ort.InferenceSession(
|
self.ort_session = _ort.InferenceSession(
|
||||||
model_path, self.get_ort_session_options(),
|
model_path, self.get_ort_session_options(), self.execution_providers
|
||||||
self.execution_providers)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _ensure_ort_session(self):
|
def _ensure_ort_session(self):
|
||||||
if self.ort_session is None:
|
if self.ort_session is None:
|
||||||
sess = _ort.InferenceSession(
|
sess = _ort.InferenceSession(
|
||||||
self.onnx_model.SerializeToString(), self.get_ort_session_options(),
|
self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers
|
||||||
self.execution_providers)
|
)
|
||||||
self.ort_session = sess
|
self.ort_session = sess
|
||||||
|
|
||||||
return self.ort_session
|
return self.ort_session
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_kwarg_device(kwargs):
|
def _get_kwarg_device(kwargs):
|
||||||
cpuonly = kwargs.get('cpu_only', None)
|
cpuonly = kwargs.get("cpu_only", None)
|
||||||
if cpuonly is not None:
|
if cpuonly is not None:
|
||||||
del kwargs['cpu_only']
|
del kwargs["cpu_only"]
|
||||||
return cpuonly
|
return cpuonly
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_customop(cls, op_type, *args, **kwargs):
|
def from_customop(cls, op_type, *args, **kwargs):
|
||||||
return (cls(cpu_only=cls._get_kwarg_device(kwargs))
|
return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
|
||||||
.create_from_customop(op_type, *args, **kwargs))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(cls, path_or_model, *args, **kwargs):
|
def from_model(cls, path_or_model, *args, **kwargs):
|
||||||
|
@ -165,9 +240,9 @@ class OrtPyFunction:
|
||||||
x = args[idx]
|
x = args[idx]
|
||||||
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
||||||
# numpy by default is int32 in some platforms, sometimes it is int64.
|
# numpy by default is int32 in some platforms, sometimes it is int64.
|
||||||
feed[i_.name] = \
|
feed[i_.name] = (
|
||||||
ts_x.astype(
|
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||||
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
feed.update(kwargs)
|
feed.update(kwargs)
|
||||||
|
@ -175,8 +250,7 @@ class OrtPyFunction:
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
self._ensure_ort_session()
|
self._ensure_ort_session()
|
||||||
outputs = self.ort_session.run(
|
outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
|
||||||
None, self._argument_map(*args, **kwargs))
|
|
||||||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,8 +265,9 @@ def optimize_model(model_or_file, output_file):
|
||||||
sess_options = OrtPyFunction().get_ort_session_options()
|
sess_options = OrtPyFunction().get_ort_session_options()
|
||||||
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||||
sess_options.optimized_model_filepath = output_file
|
sess_options.optimized_model_filepath = output_file
|
||||||
_ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
|
_ort.InferenceSession(
|
||||||
else model_or_file.SerializeToString(), sess_options)
|
model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
|
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
# include requirements.txt so pip has context to avoid installing incompatible dependencies
|
|
||||||
-r requirements.txt
|
|
||||||
pytest
|
pytest
|
||||||
# multiple versions of onnxruntime are supported, but only one can be installed at a time
|
onnx >= 1.9.0
|
||||||
protobuf < 4.0.0
|
protobuf < 4.0.0
|
||||||
|
# multiple versions of onnxruntime are supported, but only one can be installed at a time
|
||||||
onnxruntime >=1.12.0
|
onnxruntime >=1.12.0
|
||||||
transformers >=4.9.2
|
transformers >=4.9.2
|
||||||
tensorflow_text >=2.5.0;python_version < '3.11'
|
tensorflow_text >=2.5.0;python_version < '3.11'
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
onnx>=1.9.0
|
|
77
setup.py
77
setup.py
|
@ -9,14 +9,13 @@ import sys
|
||||||
import pathlib
|
import pathlib
|
||||||
import setuptools
|
import setuptools
|
||||||
|
|
||||||
from textwrap import dedent
|
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
TOP_DIR = os.path.dirname(__file__) or os.getcwd()
|
TOP_DIR = os.path.dirname(__file__) or os.getcwd()
|
||||||
PACKAGE_NAME = 'onnxruntime_extensions'
|
PACKAGE_NAME = "onnxruntime_extensions"
|
||||||
|
|
||||||
# setup.py cannot be debugged in pip command line, so the command classes are refactored into another file
|
# setup.py cannot be debugged in pip command line, so the command classes are refactored into another file
|
||||||
cmds_dir = pathlib.Path(TOP_DIR) / '.pyproject'
|
cmds_dir = pathlib.Path(TOP_DIR) / ".pyproject"
|
||||||
sys.path.append(str(cmds_dir))
|
sys.path.append(str(cmds_dir))
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import cmdclass as _cmds # noqa: E402
|
import cmdclass as _cmds # noqa: E402
|
||||||
|
@ -24,62 +23,50 @@ import cmdclass as _cmds # noqa: E402
|
||||||
_cmds.prepare_env(TOP_DIR)
|
_cmds.prepare_env(TOP_DIR)
|
||||||
|
|
||||||
|
|
||||||
def read_requirements():
|
|
||||||
with open(os.path.join(TOP_DIR, "requirements.txt"), "r", encoding="utf-8") as f:
|
|
||||||
requirements = [_ for _ in [dedent(_) for _ in f.readlines()] if _ is not None]
|
|
||||||
return requirements
|
|
||||||
|
|
||||||
|
|
||||||
# read version from the package file.
|
# read version from the package file.
|
||||||
def read_version():
|
def read_version():
|
||||||
version_str = '1.0.0'
|
version_str = "0.1.0"
|
||||||
with (open(os.path.join(TOP_DIR, 'version.txt'), "r")) as f:
|
with open(os.path.join(TOP_DIR, "version.txt"), "r") as f:
|
||||||
version_str = f.readline().strip()
|
version_str = f.readline().strip()
|
||||||
|
|
||||||
# special handling for Onebranch building
|
# special handling for Onebranch building
|
||||||
if os.getenv('BUILD_SOURCEBRANCHNAME', "").startswith('rel-'):
|
if os.getenv("BUILD_SOURCEBRANCHNAME", "").startswith("rel-"):
|
||||||
return version_str
|
return version_str
|
||||||
|
|
||||||
# is it a dev build or release?
|
# is it a dev build or release?
|
||||||
rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(
|
rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(os.path.join(TOP_DIR, ".git")) else (True, None)
|
||||||
os.path.join(TOP_DIR, '.git')) else (True, None)
|
|
||||||
|
|
||||||
if rel_br:
|
if rel_br:
|
||||||
return version_str
|
return version_str
|
||||||
|
|
||||||
build_id = os.getenv('BUILD_BUILDID', None)
|
build_id = os.getenv("BUILD_BUILDID", None)
|
||||||
if build_id is not None:
|
if build_id is not None:
|
||||||
version_str += '.{}'.format(build_id)
|
version_str += ".{}".format(build_id)
|
||||||
else:
|
else:
|
||||||
version_str += '+' + cid[:7]
|
version_str += "+" + cid[:7]
|
||||||
return version_str
|
return version_str
|
||||||
|
|
||||||
|
|
||||||
def write_py_version(ext_version):
|
def write_py_version(ext_version):
|
||||||
text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n",
|
text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n", '__version__ = "{}"\n'.format(ext_version)]
|
||||||
"__version__ = \"{}\"\n".format(ext_version)]
|
with open(os.path.join(TOP_DIR, "onnxruntime_extensions/_version.py"), "w") as _fver:
|
||||||
with (open(os.path.join(TOP_DIR, 'onnxruntime_extensions/_version.py'), "w")) as _fver:
|
|
||||||
_fver.writelines(text)
|
_fver.writelines(text)
|
||||||
|
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [setuptools.extension.Extension(name=str("onnxruntime_extensions._extensions_pydll"), sources=[])]
|
||||||
setuptools.extension.Extension(
|
|
||||||
name=str('onnxruntime_extensions._extensions_pydll'),
|
|
||||||
sources=[])
|
|
||||||
]
|
|
||||||
|
|
||||||
packages = find_packages()
|
packages = find_packages()
|
||||||
package_dir = {k: os.path.join('.', k.replace(".", "/")) for k in packages}
|
package_dir = {k: os.path.join(".", k.replace(".", "/")) for k in packages}
|
||||||
package_data = {
|
package_data = {
|
||||||
"onnxruntime_extensions": ["*.so", "*.pyd"],
|
"onnxruntime_extensions": ["*.so", "*.pyd"],
|
||||||
}
|
}
|
||||||
|
|
||||||
long_description = ''
|
long_description = ""
|
||||||
with open(os.path.join(TOP_DIR, "README.md"), 'r', encoding="utf-8") as _f:
|
with open(os.path.join(TOP_DIR, "README.md"), "r", encoding="utf-8") as _f:
|
||||||
long_description += _f.read()
|
long_description += _f.read()
|
||||||
start_pos = long_description.find('# Introduction')
|
start_pos = long_description.find("# Introduction")
|
||||||
start_pos = 0 if start_pos < 0 else start_pos
|
start_pos = 0 if start_pos < 0 else start_pos
|
||||||
end_pos = long_description.find('# Contributing')
|
end_pos = long_description.find("# Contributing")
|
||||||
long_description = long_description[start_pos:end_pos]
|
long_description = long_description[start_pos:end_pos]
|
||||||
ortx_version = read_version()
|
ortx_version = read_version()
|
||||||
write_py_version(ortx_version)
|
write_py_version(ortx_version)
|
||||||
|
@ -92,25 +79,25 @@ setup(
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
description="ONNXRuntime Extensions",
|
description="ONNXRuntime Extensions",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type='text/markdown',
|
long_description_content_type="text/markdown",
|
||||||
license='MIT License',
|
license="MIT License",
|
||||||
author='Microsoft Corporation',
|
author="Microsoft Corporation",
|
||||||
author_email='onnxruntime@microsoft.com',
|
author_email="onnxruntime@microsoft.com",
|
||||||
url='https://github.com/microsoft/onnxruntime-extensions',
|
url="https://github.com/microsoft/onnxruntime-extensions",
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
cmdclass=_cmds.ortx_cmdclass,
|
cmdclass=_cmds.ortx_cmdclass,
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=read_requirements(),
|
install_requires=[],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 4 - Beta',
|
"Development Status :: 4 - Beta",
|
||||||
'Environment :: Console',
|
"Environment :: Console",
|
||||||
'Intended Audience :: Developers',
|
"Intended Audience :: Developers",
|
||||||
'Operating System :: MacOS :: MacOS X',
|
"Operating System :: MacOS :: MacOS X",
|
||||||
'Operating System :: Microsoft :: Windows',
|
"Operating System :: Microsoft :: Windows",
|
||||||
'Operating System :: POSIX :: Linux',
|
"Operating System :: POSIX :: Linux",
|
||||||
"Programming Language :: C++",
|
"Programming Language :: C++",
|
||||||
'Programming Language :: Python',
|
"Programming Language :: Python",
|
||||||
"Programming Language :: Python :: Implementation :: CPython",
|
"Programming Language :: Python :: Implementation :: CPython",
|
||||||
'License :: OSI Approved :: MIT License'
|
"License :: OSI Approved :: MIT License",
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче