Make the onnx package optional (#653)

* make the onnx package optional

* update the ci.yml

* add a clearer error message for the missing ONNX package
Wenbing Li 2024-02-15 14:09:04 -08:00 committed by GitHub
Parent fc275e623f
Commit b045e66396
No key found matching this signature
GPG key ID: B5690EEEBB952194
8 changed files with 225 additions and 212 deletions

View file

@@ -89,8 +89,7 @@ stages:
     python -m pip install --upgrade pip
     python -m pip install --upgrade setuptools
     python -m pip install onnxruntime==$(ort.version)
-    python -m pip install -r requirements.txt
-  displayName: Install requirements.txt
+  displayName: Install requirements
 - script: |
     CPU_NUMBER=8 python -m pip install .
@@ -283,8 +282,7 @@ stages:
     python -m pip install --upgrade setuptools
     python -m pip install --upgrade wheel
     python -m pip install onnxruntime==$(ort.version)
-    python -m pip install -r requirements.txt
-  displayName: Install requirements.txt
+  displayName: Install requirements
 - script: |
     python -c "import onnxruntime;print(onnxruntime.__version__)"
@@ -419,7 +417,6 @@ stages:
     call activate pyenv
     python -m pip install --upgrade pip
     python -m pip install onnxruntime==$(ort.version)
-    python -m pip install -r requirements.txt
     python -m pip install -r requirements-dev.txt
   displayName: Install requirements{-dev}.txt and cmake python modules
@@ -653,7 +650,6 @@ stages:
     python3 -m pip install --upgrade pip; \
     python3 -m pip install --upgrade setuptools; \
     python3 -m pip install onnxruntime-gpu==$(ORT_VERSION); \
-    python3 -m pip install -r requirements.txt; \
     python3 -m pip install -v --config-settings "ortx-user-option=use-cuda" . ; \
     python3 -m pip install $(TORCH_VERSION) ; \
    python3 -m pip install -r requirements-dev.txt; \

View file

@@ -31,12 +31,15 @@ python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.git
 ## Usage
 
-## 1. Generate the pre-/post- processing ONNX model
-With onnxruntime-extensions Python package, you can easily get the ONNX processing graph by converting them from Huggingface transformer data processing classes, check the following API for details.
+## 1. Generation of Pre-/Post-Processing ONNX Model
+The `onnxruntime-extensions` Python package provides a convenient way to generate the ONNX processing graph. This is done by converting the Huggingface transformer data-processing classes into the desired format. For more detailed information, please refer to the API below:
+
 ```python
 help(onnxruntime_extensions.gen_processing_models)
 ```
-### NOTE: These data processing model can be merged into other model [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
+### NOTE:
+The generation of model processing requires the **ONNX** package to be installed. The data-processing models generated in this manner can be merged with other models using [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
 
 ## 2. Using Extensions for ONNX Runtime inference
 ### Python
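For illustration, a minimal sketch of the API the new README text describes, assuming a Huggingface tokenizer is available (the GPT-2 checkpoint and the `pre_kwargs`/`post_kwargs` arguments are illustrative; `help(gen_processing_models)` gives the authoritative signature):

```python
from transformers import GPT2Tokenizer
from onnxruntime_extensions import gen_processing_models

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Request both the pre- and post-processing graphs; per the NOTE above,
# the ONNX package must be installed for this step.
pre_model, post_model = gen_processing_models(tokenizer, pre_kwargs={}, post_kwargs={})
```

The resulting `ModelProto` objects can then be saved with `onnx.save_model` or merged into an inference model via `onnx.compose`.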

View file

@@ -10,37 +10,70 @@ This enables more flexibility and control over model execution, thus expanding t
 __author__ = "Microsoft"
 
-__all__ = [
-    'gen_processing_models',
-    'ort_inference',
-    'get_library_path',
-    'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
-    'enable_py_op',
-    'expand_onnx_inputs',
-    'hook_model_op',
-    'default_opset_domain',
-    'OrtPyFunction', 'PyOrtFunction',
-    'optimize_model',
-    'make_onnx_model',
-    'ONNXRuntimeError',
-    'hash_64',
-    '__version__',
-]
-
 from ._version import __version__
 from ._ocos import get_library_path
 from ._ocos import Opdef, PyCustomOpDef
 from ._ocos import hash_64
 from ._ocos import enable_py_op
-from ._ocos import expand_onnx_inputs
-from ._ocos import hook_model_op
 from ._ocos import default_opset_domain
-from ._cuops import *  # noqa
-from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
-from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
-from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException
-from .cvt import gen_processing_models
+
+_lib_only = False
+
+try:
+    import onnx  # noqa
+    import onnxruntime  # noqa
+except ImportError:
+    _lib_only = True
+    pass
+
+_offline_api = [
+    "gen_processing_models",
+    "ort_inference",
+    "OrtPyFunction",
+    "PyOrtFunction",
+    "optimize_model",
+    "make_onnx_model",
+    "ONNXRuntimeError",
+]
+
+__all__ = [
+    "get_library_path",
+    "Opdef",
+    "onnx_op",
+    "PyCustomOpDef",
+    "PyOp",
+    "enable_py_op",
+    "expand_onnx_inputs",
+    "hook_model_op",
+    "default_opset_domain",
+    "hash_64",
+    "__version__",
+]
 
 # rename the implementation with a more formal name
 onnx_op = Opdef.declare
 PyOp = PyCustomOpDef
+
+if _lib_only:
+
+    def _unimplemented(*args, **kwargs):
+        raise NotImplementedError("ONNX or ONNX Runtime is not installed")
+
+    gen_processing_models = _unimplemented
+    OrtPyFunction = _unimplemented
+    ort_inference = _unimplemented
+
+else:
+    __all__ += _offline_api
+
+    from ._cuops import *  # noqa
+    from ._ortapi2 import hook_model_op
+    from ._ortapi2 import expand_onnx_inputs
+    from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
+    from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
+    from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException  # noqa
+    from .cvt import gen_processing_models
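The try/except guard above is a standard optional-dependency pattern: the package imports cleanly without onnx/onnxruntime, and the ONNX-dependent entry points fail lazily with an explicit message. A minimal sketch of what a caller sees, assuming an environment where the ONNX packages are absent:

```python
import onnxruntime_extensions as ortx

# Library-only symbols keep working without onnx/onnxruntime installed.
print(ortx.get_library_path())

try:
    # In a lib-only environment this resolves to the _unimplemented stub.
    ortx.gen_processing_models(None)
except NotImplementedError as err:
    print(err)  # "ONNX or ONNX Runtime is not installed"
```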

View file

@@ -7,34 +7,36 @@ _ocos.py: PythonOp implementation
 """
 import os
 import sys
-import copy
 import glob
-import onnx
-from onnx import helper
 
 
 def _search_cuda_dir():
-    paths = os.getenv('PATH', '').split(os.pathsep)
+    paths = os.getenv("PATH", "").split(os.pathsep)
     for path in paths:
-        for filename in glob.glob(os.path.join(path, 'cudart64*.dll')):
+        for filename in glob.glob(os.path.join(path, "cudart64*.dll")):
             return os.path.dirname(filename)
 
     return None
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
     from . import _version  # noqa: E402
-    if hasattr(_version, 'cuda'):
+
+    if hasattr(_version, "cuda"):
         cuda_path = _search_cuda_dir()
         if cuda_path is None:
-            raise RuntimeError(
-                "Cannot locate CUDA directory in the environment variable for GPU package")
+            raise RuntimeError("Cannot locate CUDA directory in the environment variable for GPU package")
         os.add_dll_directory(cuda_path)
 
 from ._extensions_pydll import (  # noqa
-    PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
+    PyCustomOpDef,
+    enable_py_op,
+    add_custom_op,
+    hash_64,
+    default_opset_domain,
+)
 
 
 def get_library_path():
@@ -42,12 +44,11 @@ def get_library_path():
     The custom operator library binary path
     :return: A string of this library path.
     """
-    mod = sys.modules['onnxruntime_extensions._extensions_pydll']
+    mod = sys.modules["onnxruntime_extensions._extensions_pydll"]
     return mod.__file__
 
 
 class Opdef:
     _odlist = {}
 
     def __init__(self, op_type, func):
@@ -57,14 +58,14 @@ class Opdef:
 
     @staticmethod
     def declare(*args, **kwargs):
-        if len(args) > 0 and hasattr(args[0], '__call__'):
+        if len(args) > 0 and hasattr(args[0], "__call__"):
             raise RuntimeError("Unexpected arguments {}.".format(args))
             # return Opdef._create(args[0])
         return lambda f: Opdef.create(f, *args, **kwargs)
 
     @staticmethod
     def create(func, *args, **kwargs):
-        name = kwargs.get('op_type', None)
+        name = kwargs.get("op_type", None)
         op_type = name or func.__name__
         opdef = Opdef(op_type, func)
         od_id = id(opdef)
@@ -76,15 +77,15 @@ class Opdef:
         opdef._nativedef.op_type = op_type
         opdef._nativedef.obj_id = od_id
 
-        inputs = kwargs.get('inputs', None)
+        inputs = kwargs.get("inputs", None)
         if inputs is None:
             inputs = [PyCustomOpDef.dt_float]
         opdef._nativedef.input_types = inputs
-        outputs = kwargs.get('outputs', None)
+        outputs = kwargs.get("outputs", None)
         if outputs is None:
             outputs = [PyCustomOpDef.dt_float]
         opdef._nativedef.output_types = outputs
-        attrs = kwargs.get('attrs', None)
+        attrs = kwargs.get("attrs", None)
         if attrs is None:
             attrs = {}
         elif isinstance(attrs, (list, tuple)):
@@ -106,16 +107,15 @@ class Opdef:
             elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
                 res[k] = v
             else:
-                raise RuntimeError("Unsupported attribute type {}.".format(
-                    self._nativedef.attrs[k]))
+                raise RuntimeError("Unsupported attribute type {}.".format(self._nativedef.attrs[k]))
         return res
 
 
 def _on_pyop_invocation(k_id, feed, attributes):
     if k_id not in Opdef._odlist:
         raise RuntimeError(
-            "Unable to find function id={}. "
-            "Did you decorate the operator with @onnx_op?.".format(k_id))
+            "Unable to find function id={}. " "Did you decorate the operator with @onnx_op?.".format(k_id)
+        )
     op_ = Opdef._odlist[k_id]
     rv = op_.body(*feed, **op_.cast_attributes(attributes))
     if isinstance(rv, tuple):
@@ -127,86 +127,7 @@ def _on_pyop_invocation(k_id, feed, attributes):
         res = tuple(res)
     else:
         res = (rv.shape, rv.flatten().tolist())
-    return (k_id, ) + res
-
-
-def _ensure_opset_domain(model):
-    op_domain_name = default_opset_domain()
-    domain_missing = True
-    for oi_ in model.opset_import:
-        if oi_.domain == op_domain_name:
-            domain_missing = False
-
-    if domain_missing:
-        model.opset_import.extend(
-            [helper.make_operatorsetid(op_domain_name, 1)])
-
-    return model
-
-
-def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
-    """
-    Replace the existing inputs of a model with the new inputs, plus some extra nodes
-    :param model: The ONNX model loaded as ModelProto
-    :param target_input: The input name to be replaced
-    :param extra_nodes: The extra nodes to be added
-    :param new_inputs: The new input (type: ValueInfoProto) sequence
-    :return: The ONNX model after modification
-    """
-    graph = model.graph
-    new_inputs = [n for n in graph.input if n.name !=
-                  target_input] + new_inputs
-    new_nodes = list(model.graph.node) + extra_nodes
-    new_graph = helper.make_graph(
-        new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
-
-    new_model = copy.deepcopy(model)
-    new_model.graph.CopyFrom(new_graph)
-
-    return _ensure_opset_domain(new_model)
-
-
-def hook_model_op(model, node_name, hook_func, input_types):
-    """
-    Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
-    :param model: The ONNX model loaded as ModelProto
-    :param node_name: The node name where the hook will be installed
-    :param hook_func: The hook function, callback on the model inference
-    :param input_types: The input types as a list
-    :return: The ONNX model with the hook installed
-    """
-    # onnx.shape_inference is very unstable, useless.
-    # hkd_model = shape_inference.infer_shapes(model)
-    hkd_model = model
-
-    n_idx = 0
-    hnode, nnode = (None, None)
-    nodes = list(hkd_model.graph.node)
-    brkpt_name = node_name + '_hkd'
-    optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
-    for n_ in nodes:
-        if n_.name == node_name:
-            input_names = list(n_.input)
-            brk_output_name = [i_ + '_hkd' for i_ in input_names]
-            hnode = onnx.helper.make_node(
-                optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain())
-            nnode = n_
-            del nnode.input[:]
-            nnode.input.extend(brk_output_name)
-            break
-        n_idx += 1
-
-    if hnode is None:
-        raise ValueError("{} is not an operator node name".format(node_name))
-
-    repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx+1:]
-    del hkd_model.graph.node[:]
-    hkd_model.graph.node.extend(repacked)
-
-    Opdef.create(hook_func, op_type=optype_name,
-                 inputs=input_types, outputs=input_types)
-    return _ensure_opset_domain(hkd_model)
+    return (k_id,) + res
 
 
 PyCustomOpDef.install_hooker(_on_pyop_invocation)
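After this change, _ocos.py keeps only the ONNX-independent PythonOp machinery. For context, a minimal sketch of the Opdef/@onnx_op registration path it implements (the Inverse op and its numpy body are illustrative, not part of this commit):

```python
import numpy as np
from onnxruntime_extensions import onnx_op, PyCustomOpDef

# Register a Python-implemented custom operator; onnx_op is Opdef.declare,
# so this fills Opdef._odlist and the native op definition.
@onnx_op(op_type="Inverse",
         inputs=[PyCustomOpDef.dt_float],
         outputs=[PyCustomOpDef.dt_float])
def inverse(x):
    # Invoked through _on_pyop_invocation at inference time.
    return np.linalg.inv(x)
```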

View file

@@ -7,8 +7,9 @@
 _ortapi2.py: ONNXRuntime-Extensions Python API
 """
 
+import copy
 import numpy as np
-from ._ocos import default_opset_domain, get_library_path  # noqa
+from ._ocos import default_opset_domain, get_library_path, Opdef
 from ._cuops import onnx, onnx_proto, SingleOpGraph
 
 _ort_check_passed = False
@@ -25,6 +26,82 @@ if not _ort_check_passed:
     raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
 
 
+def _ensure_opset_domain(model):
+    op_domain_name = default_opset_domain()
+    domain_missing = True
+    for oi_ in model.opset_import:
+        if oi_.domain == op_domain_name:
+            domain_missing = False
+
+    if domain_missing:
+        model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)])
+
+    return model
+
+
+def hook_model_op(model, node_name, hook_func, input_types):
+    """
+    Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
+    :param model: The ONNX model loaded as ModelProto
+    :param node_name: The node name where the hook will be installed
+    :param hook_func: The hook function, callback on the model inference
+    :param input_types: The input types as a list
+    :return: The ONNX model with the hook installed
+    """
+    # onnx.shape_inference is very unstable, useless.
+    # hkd_model = shape_inference.infer_shapes(model)
+    hkd_model = model
+
+    n_idx = 0
+    hnode, nnode = (None, None)
+    nodes = list(hkd_model.graph.node)
+    brkpt_name = node_name + "_hkd"
+    optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
+    for n_ in nodes:
+        if n_.name == node_name:
+            input_names = list(n_.input)
+            brk_output_name = [i_ + "_hkd" for i_ in input_names]
+            hnode = onnx.helper.make_node(
+                optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()
+            )
+            nnode = n_
+            del nnode.input[:]
+            nnode.input.extend(brk_output_name)
+            break
+        n_idx += 1
+
+    if hnode is None:
+        raise ValueError("{} is not an operator node name".format(node_name))
+
+    repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :]
+    del hkd_model.graph.node[:]
+    hkd_model.graph.node.extend(repacked)
+
+    Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types)
+    return _ensure_opset_domain(hkd_model)
+
+
+def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
+    """
+    Replace the existing inputs of a model with the new inputs, plus some extra nodes
+    :param model: The ONNX model loaded as ModelProto
+    :param target_input: The input name to be replaced
+    :param extra_nodes: The extra nodes to be added
+    :param new_inputs: The new input (type: ValueInfoProto) sequence
+    :return: The ONNX model after modification
+    """
+    graph = model.graph
+    new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs
+    new_nodes = list(model.graph.node) + extra_nodes
+    new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
+
+    new_model = copy.deepcopy(model)
+    new_model.graph.CopyFrom(new_graph)
+
+    return _ensure_opset_domain(new_model)
+
+
 def get_opset_version_from_ort():
     _ORT_OPSET_SUPPORT_TABLE = {
         "1.5": 11,
@@ -37,10 +114,10 @@ def get_opset_version_from_ort():
         "1.12": 17,
         "1.13": 17,
         "1.14": 18,
-        "1.15": 18
+        "1.15": 18,
     }
 
-    ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
+    ort_ver_string = ".".join(_ort.__version__.split(".")[0:2])
     max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
     if ort_ver_string > max_ver:
         ort_ver_string = max_ver
@@ -50,19 +127,18 @@ def get_opset_version_from_ort():
 def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
     if opset_version == 0:
         opset_version = get_opset_version_from_ort()
-    fn_mm = onnx.helper.make_model_gen_version if hasattr(onnx.helper, 'make_model_gen_version'
-                                                          ) else onnx.helper.make_model
-    model = fn_mm(graph, opset_imports=[
-        onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
-    model.opset_import.extend(
-        [onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
+    fn_mm = (
+        onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model
+    )
+    model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)])
+    model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
     return model
 
 
 class OrtPyFunction:
     """
     OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
     equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
     standard Python function. The order of the function arguments correlates directly with
     the sequence of the input/output in the ONNX graph.
     """
@@ -78,10 +154,10 @@ class OrtPyFunction:
         self._onnx_model = None
         self.ort_session = None
         self.default_inputs = {}
-        self.execution_providers = ['CPUExecutionProvider']
+        self.execution_providers = ["CPUExecutionProvider"]
         if not cpu_only:
-            if _ort.get_device() == 'GPU':
-                self.execution_providers = ['CUDAExecutionProvider']
+            if _ort.get_device() == "GPU":
+                self.execution_providers = ["CUDAExecutionProvider"]
         self.extra_session_options = {}
         mpath = None
         if isinstance(path_or_model, str):
@@ -99,8 +175,8 @@ class OrtPyFunction:
     def add_default_input(self, **kwargs):
         inputs = {
-            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
-            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
+            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8)
+            for ky_, val_ in kwargs.items()
         }
 
         self.default_inputs.update(inputs)
@@ -124,30 +200,29 @@ class OrtPyFunction:
         self._oxml = oxml
         if model_path is not None:
             self.ort_session = _ort.InferenceSession(
-                model_path, self.get_ort_session_options(),
-                self.execution_providers)
+                model_path, self.get_ort_session_options(), self.execution_providers
+            )
         return self
 
     def _ensure_ort_session(self):
         if self.ort_session is None:
             sess = _ort.InferenceSession(
-                self.onnx_model.SerializeToString(), self.get_ort_session_options(),
-                self.execution_providers)
+                self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers
+            )
             self.ort_session = sess
 
         return self.ort_session
 
     @staticmethod
     def _get_kwarg_device(kwargs):
-        cpuonly = kwargs.get('cpu_only', None)
+        cpuonly = kwargs.get("cpu_only", None)
         if cpuonly is not None:
-            del kwargs['cpu_only']
+            del kwargs["cpu_only"]
         return cpuonly
 
     @classmethod
     def from_customop(cls, op_type, *args, **kwargs):
-        return (cls(cpu_only=cls._get_kwarg_device(kwargs))
-                .create_from_customop(op_type, *args, **kwargs))
+        return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
 
     @classmethod
     def from_model(cls, path_or_model, *args, **kwargs):
@@ -165,9 +240,9 @@ class OrtPyFunction:
             x = args[idx]
             ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
             # numpy by default is int32 in some platforms, sometimes it is int64.
-            feed[i_.name] = \
-                ts_x.astype(
-                    np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
+            feed[i_.name] = (
+                ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
+            )
             idx += 1
 
         feed.update(kwargs)
@@ -175,8 +250,7 @@ class OrtPyFunction:
     def __call__(self, *args, **kwargs):
         self._ensure_ort_session()
-        outputs = self.ort_session.run(
-            None, self._argument_map(*args, **kwargs))
+        outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
 
         return outputs[0] if len(outputs) == 1 else tuple(outputs)
@@ -191,8 +265,9 @@ def optimize_model(model_or_file, output_file):
     sess_options = OrtPyFunction().get_ort_session_options()
     sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
     sess_options.optimized_model_filepath = output_file
-    _ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
-                          else model_or_file.SerializeToString(), sess_options)
+    _ort.InferenceSession(
+        model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options
+    )
 
 
 ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
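OrtPyFunction, reformatted above, wraps an InferenceSession so a model runs like a Python function. A minimal usage sketch (the model path and the float input are placeholders, not part of this commit):

```python
import numpy as np
from onnxruntime_extensions import OrtPyFunction

# "model.onnx" stands for any model that uses extensions custom ops.
fn = OrtPyFunction.from_model("model.onnx")

# Positional arguments map to the graph inputs in order; a single output
# comes back as one array, multiple outputs as a tuple (see __call__ above).
result = fn(np.array([[1.0, 2.0]], dtype=np.float32))
```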

View file

@@ -1,8 +1,7 @@
-# include requirements.txt so pip has context to avoid installing incompatible dependencies
--r requirements.txt
 pytest
-# multiple versions of onnxruntime are supported, but only one can be installed at a time
+onnx >= 1.9.0
 protobuf < 4.0.0
+# multiple versions of onnxruntime are supported, but only one can be installed at a time
 onnxruntime >=1.12.0
 transformers >=4.9.2
 tensorflow_text >=2.5.0;python_version < '3.11'
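requirements-dev.txt now lists onnx directly because the base package no longer requires it. A small sketch for test setups that want to skip ONNX-dependent cases when the optional packages are absent (the pytest-based layout is an assumption, not part of this commit):

```python
import importlib.util
import pytest

# Skip ONNX-dependent tests when the optional dependency is not installed.
onnx_available = importlib.util.find_spec("onnx") is not None
requires_onnx = pytest.mark.skipif(not onnx_available, reason="onnx not installed")

@requires_onnx
def test_gen_processing_models_importable():
    from onnxruntime_extensions import gen_processing_models
    assert callable(gen_processing_models)
```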

View file

@@ -1 +0,0 @@
-onnx>=1.9.0

View file

@@ -9,14 +9,13 @@ import sys
 import pathlib
 import setuptools
-from textwrap import dedent
 from setuptools import setup, find_packages
 
 TOP_DIR = os.path.dirname(__file__) or os.getcwd()
-PACKAGE_NAME = 'onnxruntime_extensions'
+PACKAGE_NAME = "onnxruntime_extensions"
 
 # setup.py cannot be debugged in pip command line, so the command classes are refactored into another file
-cmds_dir = pathlib.Path(TOP_DIR) / '.pyproject'
+cmds_dir = pathlib.Path(TOP_DIR) / ".pyproject"
 sys.path.append(str(cmds_dir))
 # noinspection PyUnresolvedReferences
 import cmdclass as _cmds  # noqa: E402
@@ -24,62 +23,50 @@ import cmdclass as _cmds  # noqa: E402
 _cmds.prepare_env(TOP_DIR)
 
 
-def read_requirements():
-    with open(os.path.join(TOP_DIR, "requirements.txt"), "r", encoding="utf-8") as f:
-        requirements = [_ for _ in [dedent(_) for _ in f.readlines()] if _ is not None]
-    return requirements
-
-
 # read version from the package file.
 def read_version():
-    version_str = '1.0.0'
-    with (open(os.path.join(TOP_DIR, 'version.txt'), "r")) as f:
+    version_str = "0.1.0"
+    with open(os.path.join(TOP_DIR, "version.txt"), "r") as f:
         version_str = f.readline().strip()
 
     # special handling for Onebranch building
-    if os.getenv('BUILD_SOURCEBRANCHNAME', "").startswith('rel-'):
+    if os.getenv("BUILD_SOURCEBRANCHNAME", "").startswith("rel-"):
         return version_str
 
     # is it a dev build or release?
-    rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(
-        os.path.join(TOP_DIR, '.git')) else (True, None)
+    rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(os.path.join(TOP_DIR, ".git")) else (True, None)
 
     if rel_br:
         return version_str
 
-    build_id = os.getenv('BUILD_BUILDID', None)
+    build_id = os.getenv("BUILD_BUILDID", None)
     if build_id is not None:
-        version_str += '.{}'.format(build_id)
+        version_str += ".{}".format(build_id)
     else:
-        version_str += '+' + cid[:7]
+        version_str += "+" + cid[:7]
 
     return version_str
 
 
 def write_py_version(ext_version):
-    text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n",
-            "__version__ = \"{}\"\n".format(ext_version)]
-    with (open(os.path.join(TOP_DIR, 'onnxruntime_extensions/_version.py'), "w")) as _fver:
+    text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n", '__version__ = "{}"\n'.format(ext_version)]
+    with open(os.path.join(TOP_DIR, "onnxruntime_extensions/_version.py"), "w") as _fver:
         _fver.writelines(text)
 
 
-ext_modules = [
-    setuptools.extension.Extension(
-        name=str('onnxruntime_extensions._extensions_pydll'),
-        sources=[])
-]
+ext_modules = [setuptools.extension.Extension(name=str("onnxruntime_extensions._extensions_pydll"), sources=[])]
 
 packages = find_packages()
-package_dir = {k: os.path.join('.', k.replace(".", "/")) for k in packages}
+package_dir = {k: os.path.join(".", k.replace(".", "/")) for k in packages}
 package_data = {
     "onnxruntime_extensions": ["*.so", "*.pyd"],
 }
 
-long_description = ''
-with open(os.path.join(TOP_DIR, "README.md"), 'r', encoding="utf-8") as _f:
+long_description = ""
+with open(os.path.join(TOP_DIR, "README.md"), "r", encoding="utf-8") as _f:
     long_description += _f.read()
-    start_pos = long_description.find('# Introduction')
+    start_pos = long_description.find("# Introduction")
     start_pos = 0 if start_pos < 0 else start_pos
-    end_pos = long_description.find('# Contributing')
+    end_pos = long_description.find("# Contributing")
     long_description = long_description[start_pos:end_pos]
 
 ortx_version = read_version()
 write_py_version(ortx_version)
@@ -92,25 +79,25 @@ setup(
     package_data=package_data,
     description="ONNXRuntime Extensions",
     long_description=long_description,
-    long_description_content_type='text/markdown',
-    license='MIT License',
-    author='Microsoft Corporation',
-    author_email='onnxruntime@microsoft.com',
-    url='https://github.com/microsoft/onnxruntime-extensions',
+    long_description_content_type="text/markdown",
+    license="MIT License",
+    author="Microsoft Corporation",
+    author_email="onnxruntime@microsoft.com",
+    url="https://github.com/microsoft/onnxruntime-extensions",
     ext_modules=ext_modules,
     cmdclass=_cmds.ortx_cmdclass,
     include_package_data=True,
-    install_requires=read_requirements(),
+    install_requires=[],
     classifiers=[
-        'Development Status :: 4 - Beta',
-        'Environment :: Console',
-        'Intended Audience :: Developers',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: POSIX :: Linux',
+        "Development Status :: 4 - Beta",
+        "Environment :: Console",
+        "Intended Audience :: Developers",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: POSIX :: Linux",
         "Programming Language :: C++",
-        'Programming Language :: Python',
+        "Programming Language :: Python",
         "Programming Language :: Python :: Implementation :: CPython",
-        'License :: OSI Approved :: MIT License'
-    ]
+        "License :: OSI Approved :: MIT License",
+    ],
 )
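With install_requires now empty, a plain pip install of the wheel pulls in neither onnx nor onnxruntime, so which APIs a given installation exports depends entirely on the environment. A quick sketch to probe this at runtime (it relies only on the __all__ logic from the __init__.py change above):

```python
import onnxruntime_extensions as ortx

# True when onnx/onnxruntime are installed and the offline APIs were
# re-exported (__all__ += _offline_api); False in a lib-only environment.
print("gen_processing_models" in ortx.__all__)
```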