Make the onnx package optional (#653)

* make the onnx package optional

* update the ci.yml

* add a clearer error message for the missing ONNX package
Wenbing Li 2024-02-15 14:09:04 -08:00 committed by GitHub
Parent fc275e623f
Commit b045e66396
No key found matching this signature
GPG key ID: B5690EEEBB952194
8 changed files with 225 additions and 212 deletions

View File

@@ -89,8 +89,7 @@ stages:
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
python -m pip install onnxruntime==$(ort.version)
python -m pip install -r requirements.txt
displayName: Install requirements.txt
displayName: Install requirements
- script: |
CPU_NUMBER=8 python -m pip install .
@@ -283,8 +282,7 @@ stages:
python -m pip install --upgrade setuptools
python -m pip install --upgrade wheel
python -m pip install onnxruntime==$(ort.version)
python -m pip install -r requirements.txt
displayName: Install requirements.txt
displayName: Install requirements
- script: |
python -c "import onnxruntime;print(onnxruntime.__version__)"
@@ -419,7 +417,6 @@ stages:
call activate pyenv
python -m pip install --upgrade pip
python -m pip install onnxruntime==$(ort.version)
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
displayName: Install requirements{-dev}.txt and cmake python modules
@@ -653,7 +650,6 @@ stages:
python3 -m pip install --upgrade pip; \
python3 -m pip install --upgrade setuptools; \
python3 -m pip install onnxruntime-gpu==$(ORT_VERSION); \
python3 -m pip install -r requirements.txt; \
python3 -m pip install -v --config-settings "ortx-user-option=use-cuda" . ; \
python3 -m pip install $(TORCH_VERSION) ; \
python3 -m pip install -r requirements-dev.txt; \

View File

@@ -31,12 +31,15 @@ python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.git
## Usage
## 1. Generate the pre-/post- processing ONNX model
With onnxruntime-extensions Python package, you can easily get the ONNX processing graph by converting them from Huggingface transformer data processing classes, check the following API for details.
## 1. Generation of Pre-/Post-Processing ONNX Model
The `onnxruntime-extensions` Python package provides a convenient way to generate the ONNX processing graph. This can be achieved by converting the Huggingface transformer data processing classes into the desired format. For more detailed information, please refer to the API below:
```python
help(onnxruntime_extensions.gen_processing_models)
```
### NOTE: These data processing model can be merged into other model [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
### NOTE:
Generating the processing models requires the **ONNX** package to be installed. The models generated this way can be merged with other models using [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
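For illustration, a minimal sketch of this workflow (assuming the optional `onnx` and `transformers` packages are installed; `gpt2` is just an example checkpoint):
```python
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models

# Convert a Huggingface tokenizer into pre-/post-processing ONNX models.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pre_model, post_model = gen_processing_models(tokenizer, pre_kwargs={}, post_kwargs={})

# If needed, stitch the pre-processing model onto an inference model with
# onnx.compose; the io_map pairing below is hypothetical.
# import onnx
# merged = onnx.compose.merge_models(pre_model, core_model,
#                                    io_map=[("input_ids", "input_ids")])
```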
## 2. Using Extensions for ONNX Runtime inference
### Python

View File

@@ -10,37 +10,70 @@ This enables more flexibility and control over model execution, thus expanding t
__author__ = "Microsoft"
__all__ = [
'gen_processing_models',
'ort_inference',
'get_library_path',
'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp',
'enable_py_op',
'expand_onnx_inputs',
'hook_model_op',
'default_opset_domain',
'OrtPyFunction', 'PyOrtFunction',
'optimize_model',
'make_onnx_model',
'ONNXRuntimeError',
'hash_64',
'__version__',
]
from ._version import __version__
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
from ._ocos import hash_64
from ._ocos import enable_py_op
from ._ocos import expand_onnx_inputs
from ._ocos import hook_model_op
from ._ocos import default_opset_domain
from ._cuops import * # noqa
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException
from .cvt import gen_processing_models
_lib_only = False
try:
import onnx # noqa
import onnxruntime # noqa
except ImportError:
_lib_only = True
pass
_offline_api = [
"gen_processing_models",
"ort_inference",
"OrtPyFunction",
"PyOrtFunction",
"optimize_model",
"make_onnx_model",
"ONNXRuntimeError",
]
__all__ = [
"get_library_path",
"Opdef",
"onnx_op",
"PyCustomOpDef",
"PyOp",
"enable_py_op",
"expand_onnx_inputs",
"hook_model_op",
"default_opset_domain",
"hash_64",
"__version__",
]
# rename the implementation with a more formal name
onnx_op = Opdef.declare
PyOp = PyCustomOpDef
if _lib_only:
def _unimplemented(*args, **kwargs):
raise NotImplementedError("ONNX or ONNX Runtime is not installed")
gen_processing_models = _unimplemented
OrtPyFunction = _unimplemented
ort_inference = _unimplemented
else:
__all__ += _offline_api
from ._cuops import * # noqa
from ._ortapi2 import hook_model_op
from ._ortapi2 import expand_onnx_inputs
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException # noqa
from .cvt import gen_processing_models
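The net effect of this fallback: the package still imports when `onnx`/`onnxruntime` are absent, but the offline APIs fail fast with an explicit error. A rough sketch of the resulting behavior in such an environment:
```python
import onnxruntime_extensions as ortx

# Works in lib-only mode: locating the native custom-op library
# does not require onnx or onnxruntime.
print(ortx.get_library_path())

# Replaced by the _unimplemented stub in lib-only mode.
try:
    ortx.gen_processing_models(None)
except NotImplementedError as err:
    print(err)  # ONNX or ONNX Runtime is not installed
```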

View File

@ -7,34 +7,36 @@ _ocos.py: PythonOp implementation
"""
import os
import sys
import copy
import glob
import onnx
from onnx import helper
def _search_cuda_dir():
paths = os.getenv('PATH', '').split(os.pathsep)
paths = os.getenv("PATH", "").split(os.pathsep)
for path in paths:
for filename in glob.glob(os.path.join(path, 'cudart64*.dll')):
for filename in glob.glob(os.path.join(path, "cudart64*.dll")):
return os.path.dirname(filename)
return None
if sys.platform == 'win32':
if sys.platform == "win32":
from . import _version # noqa: E402
if hasattr(_version, 'cuda'):
if hasattr(_version, "cuda"):
cuda_path = _search_cuda_dir()
if cuda_path is None:
raise RuntimeError(
"Cannot locate CUDA directory in the environment variable for GPU package")
raise RuntimeError("Cannot locate CUDA directory in the environment variable for GPU package")
os.add_dll_directory(cuda_path)
from ._extensions_pydll import ( # noqa
PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
PyCustomOpDef,
enable_py_op,
add_custom_op,
hash_64,
default_opset_domain,
)
def get_library_path():
@@ -42,12 +44,11 @@ def get_library_path():
The custom operator library binary path
:return: A string of this library path.
"""
mod = sys.modules['onnxruntime_extensions._extensions_pydll']
mod = sys.modules["onnxruntime_extensions._extensions_pydll"]
return mod.__file__
class Opdef:
_odlist = {}
def __init__(self, op_type, func):
@@ -57,14 +58,14 @@ class Opdef:
@staticmethod
def declare(*args, **kwargs):
if len(args) > 0 and hasattr(args[0], '__call__'):
if len(args) > 0 and hasattr(args[0], "__call__"):
raise RuntimeError("Unexpected arguments {}.".format(args))
# return Opdef._create(args[0])
return lambda f: Opdef.create(f, *args, **kwargs)
@staticmethod
def create(func, *args, **kwargs):
name = kwargs.get('op_type', None)
name = kwargs.get("op_type", None)
op_type = name or func.__name__
opdef = Opdef(op_type, func)
od_id = id(opdef)
@@ -76,15 +77,15 @@ class Opdef:
opdef._nativedef.op_type = op_type
opdef._nativedef.obj_id = od_id
inputs = kwargs.get('inputs', None)
inputs = kwargs.get("inputs", None)
if inputs is None:
inputs = [PyCustomOpDef.dt_float]
opdef._nativedef.input_types = inputs
outputs = kwargs.get('outputs', None)
outputs = kwargs.get("outputs", None)
if outputs is None:
outputs = [PyCustomOpDef.dt_float]
opdef._nativedef.output_types = outputs
attrs = kwargs.get('attrs', None)
attrs = kwargs.get("attrs", None)
if attrs is None:
attrs = {}
elif isinstance(attrs, (list, tuple)):
@@ -106,16 +107,15 @@ class Opdef:
elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
res[k] = v
else:
raise RuntimeError("Unsupported attribute type {}.".format(
self._nativedef.attrs[k]))
raise RuntimeError("Unsupported attribute type {}.".format(self._nativedef.attrs[k]))
return res
def _on_pyop_invocation(k_id, feed, attributes):
if k_id not in Opdef._odlist:
raise RuntimeError(
"Unable to find function id={}. "
"Did you decorate the operator with @onnx_op?.".format(k_id))
"Unable to find function id={}. " "Did you decorate the operator with @onnx_op?.".format(k_id)
)
op_ = Opdef._odlist[k_id]
rv = op_.body(*feed, **op_.cast_attributes(attributes))
if isinstance(rv, tuple):
@@ -127,86 +127,7 @@ def _on_pyop_invocation(k_id, feed, attributes):
res = tuple(res)
else:
res = (rv.shape, rv.flatten().tolist())
return (k_id, ) + res
def _ensure_opset_domain(model):
op_domain_name = default_opset_domain()
domain_missing = True
for oi_ in model.opset_import:
if oi_.domain == op_domain_name:
domain_missing = False
if domain_missing:
model.opset_import.extend(
[helper.make_operatorsetid(op_domain_name, 1)])
return model
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
"""
Replace the existing inputs of a model with the new inputs, plus some extra nodes
:param model: The ONNX model loaded as ModelProto
:param target_input: The input name to be replaced
:param extra_nodes: The extra nodes to be added
:param new_inputs: The new input (type: ValueInfoProto) sequence
:return: The ONNX model after modification
"""
graph = model.graph
new_inputs = [n for n in graph.input if n.name !=
target_input] + new_inputs
new_nodes = list(model.graph.node) + extra_nodes
new_graph = helper.make_graph(
new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
new_model = copy.deepcopy(model)
new_model.graph.CopyFrom(new_graph)
return _ensure_opset_domain(new_model)
def hook_model_op(model, node_name, hook_func, input_types):
"""
Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
:param model: The ONNX model loaded as ModelProto
:param node_name: The node name where the hook will be installed
:param hook_func: The hook function, callback on the model inference
:param input_types: The input types as a list
:return: The ONNX model with the hook installed
"""
# onnx.shape_inference is very unstable, useless.
# hkd_model = shape_inference.infer_shapes(model)
hkd_model = model
n_idx = 0
hnode, nnode = (None, None)
nodes = list(hkd_model.graph.node)
brkpt_name = node_name + '_hkd'
optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
for n_ in nodes:
if n_.name == node_name:
input_names = list(n_.input)
brk_output_name = [i_ + '_hkd' for i_ in input_names]
hnode = onnx.helper.make_node(
optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain())
nnode = n_
del nnode.input[:]
nnode.input.extend(brk_output_name)
break
n_idx += 1
if hnode is None:
raise ValueError("{} is not an operator node name".format(node_name))
repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx+1:]
del hkd_model.graph.node[:]
hkd_model.graph.node.extend(repacked)
Opdef.create(hook_func, op_type=optype_name,
inputs=input_types, outputs=input_types)
return _ensure_opset_domain(hkd_model)
return (k_id,) + res
PyCustomOpDef.install_hooker(_on_pyop_invocation)
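For context, `Opdef.declare` is what the package re-exports as the `@onnx_op` decorator (see the `__init__.py` diff above). A minimal, illustrative declaration, with the op name and NumPy body invented for the example:
```python
import numpy as np
from onnxruntime_extensions import onnx_op, PyCustomOpDef

@onnx_op(op_type="ReverseMatrix",
         inputs=[PyCustomOpDef.dt_float],
         outputs=[PyCustomOpDef.dt_float])
def reverse_matrix(x):
    # The body is called back through _on_pyop_invocation with NumPy inputs.
    return np.flip(x, axis=0).astype(np.float32)
```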

View File

@@ -7,8 +7,9 @@
_ortapi2.py: ONNXRuntime-Extensions Python API
"""
import copy
import numpy as np
from ._ocos import default_opset_domain, get_library_path # noqa
from ._ocos import default_opset_domain, get_library_path, Opdef
from ._cuops import onnx, onnx_proto, SingleOpGraph
_ort_check_passed = False
@@ -25,6 +26,82 @@ if not _ort_check_passed:
raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
def _ensure_opset_domain(model):
op_domain_name = default_opset_domain()
domain_missing = True
for oi_ in model.opset_import:
if oi_.domain == op_domain_name:
domain_missing = False
if domain_missing:
model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)])
return model
def hook_model_op(model, node_name, hook_func, input_types):
"""
Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
:param model: The ONNX model loaded as ModelProto
:param node_name: The node name where the hook will be installed
:param hook_func: The hook function, callback on the model inference
:param input_types: The input types as a list
:return: The ONNX model with the hook installed
"""
# onnx.shape_inference is very unstable, useless.
# hkd_model = shape_inference.infer_shapes(model)
hkd_model = model
n_idx = 0
hnode, nnode = (None, None)
nodes = list(hkd_model.graph.node)
brkpt_name = node_name + "_hkd"
optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
for n_ in nodes:
if n_.name == node_name:
input_names = list(n_.input)
brk_output_name = [i_ + "_hkd" for i_ in input_names]
hnode = onnx.helper.make_node(
optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()
)
nnode = n_
del nnode.input[:]
nnode.input.extend(brk_output_name)
break
n_idx += 1
if hnode is None:
raise ValueError("{} is not an operator node name".format(node_name))
repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :]
del hkd_model.graph.node[:]
hkd_model.graph.node.extend(repacked)
Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types)
return _ensure_opset_domain(hkd_model)
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
"""
Replace the existing inputs of a model with the new inputs, plus some extra nodes
:param model: The ONNX model loaded as ModelProto
:param target_input: The input name to be replaced
:param extra_nodes: The extra nodes to be added
:param new_inputs: The new input (type: ValueInfoProto) sequence
:return: The ONNX model after modification
"""
graph = model.graph
new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs
new_nodes = list(model.graph.node) + extra_nodes
new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
new_model = copy.deepcopy(model)
new_model.graph.CopyFrom(new_graph)
return _ensure_opset_domain(new_model)
def get_opset_version_from_ort():
_ORT_OPSET_SUPPORT_TABLE = {
"1.5": 11,
@@ -37,10 +114,10 @@ def get_opset_version_from_ort():
"1.12": 17,
"1.13": 17,
"1.14": 18,
"1.15": 18
"1.15": 18,
}
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
ort_ver_string = ".".join(_ort.__version__.split(".")[0:2])
max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
if ort_ver_string > max_ver:
ort_ver_string = max_ver
@@ -50,19 +127,18 @@ def get_opset_version_from_ort():
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
if opset_version == 0:
opset_version = get_opset_version_from_ort()
fn_mm = onnx.helper.make_model_gen_version if hasattr(onnx.helper, 'make_model_gen_version'
) else onnx.helper.make_model
model = fn_mm(graph, opset_imports=[
onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
model.opset_import.extend(
[onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
fn_mm = (
onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model
)
model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)])
model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
return model
class OrtPyFunction:
"""
OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
standard Python function. The order of the function arguments correlates directly with
the sequence of the input/output in the ONNX graph.
"""
@@ -78,10 +154,10 @@ class OrtPyFunction:
self._onnx_model = None
self.ort_session = None
self.default_inputs = {}
self.execution_providers = ['CPUExecutionProvider']
self.execution_providers = ["CPUExecutionProvider"]
if not cpu_only:
if _ort.get_device() == 'GPU':
self.execution_providers = ['CUDAExecutionProvider']
if _ort.get_device() == "GPU":
self.execution_providers = ["CUDAExecutionProvider"]
self.extra_session_options = {}
mpath = None
if isinstance(path_or_model, str):
@@ -99,8 +175,8 @@ class OrtPyFunction:
def add_default_input(self, **kwargs):
inputs = {
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8)
for ky_, val_ in kwargs.items()
}
self.default_inputs.update(inputs)
@@ -124,30 +200,29 @@ class OrtPyFunction:
self._oxml = oxml
if model_path is not None:
self.ort_session = _ort.InferenceSession(
model_path, self.get_ort_session_options(),
self.execution_providers)
model_path, self.get_ort_session_options(), self.execution_providers
)
return self
def _ensure_ort_session(self):
if self.ort_session is None:
sess = _ort.InferenceSession(
self.onnx_model.SerializeToString(), self.get_ort_session_options(),
self.execution_providers)
self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers
)
self.ort_session = sess
return self.ort_session
@staticmethod
def _get_kwarg_device(kwargs):
cpuonly = kwargs.get('cpu_only', None)
cpuonly = kwargs.get("cpu_only", None)
if cpuonly is not None:
del kwargs['cpu_only']
del kwargs["cpu_only"]
return cpuonly
@classmethod
def from_customop(cls, op_type, *args, **kwargs):
return (cls(cpu_only=cls._get_kwarg_device(kwargs))
.create_from_customop(op_type, *args, **kwargs))
return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
@classmethod
def from_model(cls, path_or_model, *args, **kwargs):
@@ -165,9 +240,9 @@ class OrtPyFunction:
x = args[idx]
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
# numpy by default is int32 in some platforms, sometimes it is int64.
feed[i_.name] = \
ts_x.astype(
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
feed[i_.name] = (
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
)
idx += 1
feed.update(kwargs)
@@ -175,8 +250,7 @@ class OrtPyFunction:
def __call__(self, *args, **kwargs):
self._ensure_ort_session()
outputs = self.ort_session.run(
None, self._argument_map(*args, **kwargs))
outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
return outputs[0] if len(outputs) == 1 else tuple(outputs)
@@ -191,8 +265,9 @@ def optimize_model(model_or_file, output_file):
sess_options = OrtPyFunction().get_ort_session_options()
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = output_file
_ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
else model_or_file.SerializeToString(), sess_options)
_ort.InferenceSession(
model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options
)
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
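As a usage sketch of the wrapper refactored above (the model path is hypothetical; any ONNX model using this package's custom ops would do):
```python
import numpy as np
from onnxruntime_extensions import OrtPyFunction

# OrtPyFunction wraps an InferenceSession with the extensions library
# pre-registered, so the model is callable like a plain Python function.
fn = OrtPyFunction.from_model("model_with_custom_ops.onnx")
outputs = fn(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
```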

View File

@@ -1,8 +1,7 @@
# include requirements.txt so pip has context to avoid installing incompatible dependencies
-r requirements.txt
pytest
# multiple versions of onnxruntime are supported, but only one can be installed at a time
onnx >= 1.9.0
protobuf < 4.0.0
# multiple versions of onnxruntime are supported, but only one can be installed at a time
onnxruntime >=1.12.0
transformers >=4.9.2
tensorflow_text >=2.5.0;python_version < '3.11'

View File

@@ -1 +0,0 @@
onnx>=1.9.0

View File

@@ -9,14 +9,13 @@ import sys
import pathlib
import setuptools
from textwrap import dedent
from setuptools import setup, find_packages
TOP_DIR = os.path.dirname(__file__) or os.getcwd()
PACKAGE_NAME = 'onnxruntime_extensions'
PACKAGE_NAME = "onnxruntime_extensions"
# setup.py cannot be debugged in pip command line, so the command classes are refactored into another file
cmds_dir = pathlib.Path(TOP_DIR) / '.pyproject'
cmds_dir = pathlib.Path(TOP_DIR) / ".pyproject"
sys.path.append(str(cmds_dir))
# noinspection PyUnresolvedReferences
import cmdclass as _cmds # noqa: E402
@@ -24,62 +23,50 @@ import cmdclass as _cmds  # noqa: E402
_cmds.prepare_env(TOP_DIR)
def read_requirements():
with open(os.path.join(TOP_DIR, "requirements.txt"), "r", encoding="utf-8") as f:
requirements = [_ for _ in [dedent(_) for _ in f.readlines()] if _ is not None]
return requirements
# read version from the package file.
def read_version():
version_str = '1.0.0'
with (open(os.path.join(TOP_DIR, 'version.txt'), "r")) as f:
version_str = "0.1.0"
with open(os.path.join(TOP_DIR, "version.txt"), "r") as f:
version_str = f.readline().strip()
# special handling for Onebranch building
if os.getenv('BUILD_SOURCEBRANCHNAME', "").startswith('rel-'):
if os.getenv("BUILD_SOURCEBRANCHNAME", "").startswith("rel-"):
return version_str
# is it a dev build or release?
rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(
os.path.join(TOP_DIR, '.git')) else (True, None)
rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(os.path.join(TOP_DIR, ".git")) else (True, None)
if rel_br:
return version_str
build_id = os.getenv('BUILD_BUILDID', None)
build_id = os.getenv("BUILD_BUILDID", None)
if build_id is not None:
version_str += '.{}'.format(build_id)
version_str += ".{}".format(build_id)
else:
version_str += '+' + cid[:7]
version_str += "+" + cid[:7]
return version_str
def write_py_version(ext_version):
text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n",
"__version__ = \"{}\"\n".format(ext_version)]
with (open(os.path.join(TOP_DIR, 'onnxruntime_extensions/_version.py'), "w")) as _fver:
text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n", '__version__ = "{}"\n'.format(ext_version)]
with open(os.path.join(TOP_DIR, "onnxruntime_extensions/_version.py"), "w") as _fver:
_fver.writelines(text)
ext_modules = [
setuptools.extension.Extension(
name=str('onnxruntime_extensions._extensions_pydll'),
sources=[])
]
ext_modules = [setuptools.extension.Extension(name=str("onnxruntime_extensions._extensions_pydll"), sources=[])]
packages = find_packages()
package_dir = {k: os.path.join('.', k.replace(".", "/")) for k in packages}
package_dir = {k: os.path.join(".", k.replace(".", "/")) for k in packages}
package_data = {
"onnxruntime_extensions": ["*.so", "*.pyd"],
}
long_description = ''
with open(os.path.join(TOP_DIR, "README.md"), 'r', encoding="utf-8") as _f:
long_description = ""
with open(os.path.join(TOP_DIR, "README.md"), "r", encoding="utf-8") as _f:
long_description += _f.read()
start_pos = long_description.find('# Introduction')
start_pos = long_description.find("# Introduction")
start_pos = 0 if start_pos < 0 else start_pos
end_pos = long_description.find('# Contributing')
end_pos = long_description.find("# Contributing")
long_description = long_description[start_pos:end_pos]
ortx_version = read_version()
write_py_version(ortx_version)
@@ -92,25 +79,25 @@ setup(
package_data=package_data,
description="ONNXRuntime Extensions",
long_description=long_description,
long_description_content_type='text/markdown',
license='MIT License',
author='Microsoft Corporation',
author_email='onnxruntime@microsoft.com',
url='https://github.com/microsoft/onnxruntime-extensions',
long_description_content_type="text/markdown",
license="MIT License",
author="Microsoft Corporation",
author_email="onnxruntime@microsoft.com",
url="https://github.com/microsoft/onnxruntime-extensions",
ext_modules=ext_modules,
cmdclass=_cmds.ortx_cmdclass,
include_package_data=True,
install_requires=read_requirements(),
install_requires=[],
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Console',
'Intended Audience :: Developers',
'Operating System :: MacOS :: MacOS X',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX :: Linux',
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Developers",
"Operating System :: MacOS :: MacOS X",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Programming Language :: C++",
'Programming Language :: Python',
"Programming Language :: Python",
"Programming Language :: Python :: Implementation :: CPython",
'License :: OSI Approved :: MIT License'
]
"License :: OSI Approved :: MIT License",
],
)