Compatible with onnxruntime-gpu package (#410)

* be compatible without onnxruntime-gpu version

* some fixing
This commit is contained in:
Wenbing Li 2023-04-26 17:17:23 -07:00 коммит произвёл GitHub
Родитель 4f481d23ac
Коммит 0f45fef2d9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 51 добавлений и 19 удалений

Просмотреть файл

@ -4,11 +4,12 @@
###############################################################################
"""
The entry point to onnxruntime custom op library
The entry point to onnxruntime-extensions package.
"""
__author__ = "Microsoft"
from ._version import __version__
from ._ocos import get_library_path # noqa
from ._ocos import Opdef, PyCustomOpDef # noqa
@ -24,4 +25,3 @@ from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntim
onnx_op = Opdef.declare
PyOp = PyCustomOpDef

Просмотреть файл

@ -403,6 +403,7 @@ class AudioDecoder(CustomOp):
cls.io_def('floatPCM', onnx_proto.TensorProto.FLOAT, [1, None])
]
class StftNorm(CustomOp):
@classmethod
def get_inputs(cls):

Просмотреть файл

@ -4,9 +4,20 @@
###############################################################################
import numpy as np
import onnxruntime as _ort
from ._ocos import default_opset_domain, get_library_path # noqa
from ._cuops import * # noqa
from ._cuops import onnx, onnx_proto, CustomOpConverter, SingleOpGraph
_ort_check_passed = False
try:
from packaging import version as _ver
import onnxruntime as _ort
if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
_ort_check_passed = True
except ImportError:
pass
if not _ort_check_passed:
raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
def get_opset_version_from_ort():
@ -52,10 +63,14 @@ class OrtPyFunction:
so.register_custom_ops_library(get_library_path())
return so
def __init__(self):
def __init__(self, cpu_only=None):
self._onnx_model = None
self.ort_session = None
self.default_inputs = {}
self.execution_providers = ['CPUExecutionProvider']
if not cpu_only:
if _ort.get_device() == 'GPU':
self.execution_providers = ['CUDAExecutionProvider']
def create_from_customop(self, op_type, *args, **kwargs):
cvt = kwargs.get('cvt', None)
@ -98,20 +113,29 @@ class OrtPyFunction:
self._oxml = oxml
if model_path is not None:
self.ort_session = _ort.InferenceSession(
model_path, self.get_ort_session_options())
model_path, self.get_ort_session_options(),
self.execution_providers)
return self
def _ensure_ort_session(self):
if self.ort_session is None:
sess = _ort.InferenceSession(
self.onnx_model.SerializeToString(), self.get_ort_session_options())
self.onnx_model.SerializeToString(), self.get_ort_session_options(),
self.execution_providers)
self.ort_session = sess
return self.ort_session
@staticmethod
def _get_kwarg_device(kwargs):
cpuonly = kwargs.get('cpu_only', None)
if cpuonly is not None:
del kwargs['cpu_only']
return cpuonly
@classmethod
def from_customop(cls, op_type, *args, **kwargs):
return cls().create_from_customop(op_type, *args, **kwargs)
return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
@classmethod
def from_model(cls, path_or_model, *args, **kwargs):
@ -121,7 +145,7 @@ class OrtPyFunction:
mpath = path_or_model
else:
oxml = path_or_model
return cls()._bind(oxml, mpath)
return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)
def _argument_map(self, *args, **kwargs):
idx = 0

Просмотреть файл

@ -23,14 +23,18 @@ class PyCustomOpDef:
dt_bfloat16: int = ...
...
def enable_py_op(enabled: bool) -> bool:
...
def add_custom_op(opdef: PyCustomOpDef) -> None:
...
def hash_64(s: str, num_buckets: int, fast: int) -> int:
...
def default_opset_domain() -> str:
...

Просмотреть файл

@ -1,7 +1,8 @@
import json
import pathlib
from ._onnx_ops import make_model_ex
from .._ortapi2 import SingleOpGraph, default_opset_domain, GPT2Tokenizer, VectorToString
from .._cuops import SingleOpGraph, GPT2Tokenizer, VectorToString
from .._ortapi2 import default_opset_domain
def is_path(name_or_buffer):

Просмотреть файл

@ -5,6 +5,7 @@ import inspect
import numpy as np
# some util function for testing and tools
def get_test_data_file(*sub_dirs):
case_file = inspect.currentframe().f_back.f_code.co_filename
@ -18,7 +19,7 @@ def read_file(path, mode='r'):
def mel_filterbank(
n_fft:int, n_mels: int=80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
"""
Compute a Mel-filterbank. The filters are stored in the rows, the columns
and it is Slaney normalized mel-scale filterbank.

Просмотреть файл

@ -1,5 +1,7 @@
# include requirements.txt so pip has context to avoid installing incompatible dependencies
-r requirements.txt
pytest
# multiple versions of onnxruntime are supported, but only one can be installed at a time
onnxruntime >=1.10.0
transformers >= 4.9.2,<=4.24.0
tensorflow_text >=2.5.0

Просмотреть файл

@ -1,2 +1 @@
onnx>=1.9.0
onnxruntime>=1.10.0

Просмотреть файл

@ -99,7 +99,7 @@ class TestAudio(unittest.TestCase):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
ortx_stft = PyOrtFunction.from_model(_create_test_model())
ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
actual = actual[0]
actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
@ -109,7 +109,7 @@ class TestAudio(unittest.TestCase):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
ortx_stft = PyOrtFunction.from_customop("StftNorm")
ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
actual = actual[0]
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)

Просмотреть файл

@ -179,7 +179,7 @@ def preprocessing(audio_data):
return result
def merge_models(core: str, output_model:str, audio_data):
def merge_models(core: str, output_model: str, audio_data):
m_pre = onnx.load_model("whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
m_core = onnx.load_model(core)
m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])