Compatible with onnxruntime-gpu package (#410)

* be compatible with the onnxruntime-gpu version
* some fixes

Parent: 4f481d23ac
Commit: 0f45fef2d9
@@ -4,11 +4,12 @@
 ###############################################################################

 """
-The entry point to onnxruntime custom op library
+The entry point to onnxruntime-extensions package.
 """

 __author__ = "Microsoft"

 from ._version import __version__
 from ._ocos import get_library_path  # noqa
 from ._ocos import Opdef, PyCustomOpDef  # noqa
@@ -18,10 +19,9 @@ from ._ocos import expand_onnx_inputs  # noqa
 from ._ocos import hook_model_op  # noqa
 from ._ocos import default_opset_domain  # noqa
 from ._cuops import *  # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
+from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError


 onnx_op = Opdef.declare
 PyOp = PyCustomOpDef
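The onnx_op / PyOp aliases kept at the bottom of this module are the decorator API for Python-backed custom ops. A minimal usage sketch (the op name and schema are illustrative, adapted from the package's README example):

    import numpy as np
    from onnxruntime_extensions import onnx_op, PyOp

    @onnx_op(op_type="ReverseMatrix",
             inputs=[PyOp.dt_float],
             outputs=[PyOp.dt_float])
    def reverse_matrix(x):
        # The custom op body receives and returns numpy arrays.
        return np.flip(x, axis=0).astype(np.float32)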
@@ -403,6 +403,7 @@ class AudioDecoder(CustomOp):
             cls.io_def('floatPCM', onnx_proto.TensorProto.FLOAT, [1, None])
         ]


 class StftNorm(CustomOp):
     @classmethod
     def get_inputs(cls):
@@ -4,9 +4,20 @@
 ###############################################################################

 import numpy as np
-import onnxruntime as _ort
 from ._ocos import default_opset_domain, get_library_path  # noqa
-from ._cuops import *  # noqa
+from ._cuops import onnx, onnx_proto, CustomOpConverter, SingleOpGraph
+
+_ort_check_passed = False
+try:
+    from packaging import version as _ver
+    import onnxruntime as _ort
+    if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
+        _ort_check_passed = True
+except ImportError:
+    pass
+
+if not _ort_check_passed:
+    raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")


 def get_opset_version_from_ort():
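The new import-time gate leans on packaging.version rather than a plain string comparison. A minimal sketch of why that matters, assuming only that the packaging module is available (it is a dependency of pip/setuptools in most environments):

    from packaging import version

    # Version objects compare numerically, component by component,
    # so "1.10.0" correctly sorts above "1.9.0"...
    assert version.parse("1.10.0") > version.parse("1.9.0")
    # ...whereas lexicographic string comparison gets it backwards.
    assert not ("1.10.0" > "1.9.0")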
@@ -52,10 +63,14 @@ class OrtPyFunction:
         so.register_custom_ops_library(get_library_path())
         return so

-    def __init__(self):
+    def __init__(self, cpu_only=None):
         self._onnx_model = None
         self.ort_session = None
         self.default_inputs = {}
+        self.execution_providers = ['CPUExecutionProvider']
+        if not cpu_only:
+            if _ort.get_device() == 'GPU':
+                self.execution_providers = ['CUDAExecutionProvider']

     def create_from_customop(self, op_type, *args, **kwargs):
         cvt = kwargs.get('cvt', None)
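With this change, provider selection happens at construction time: _ort.get_device() reports 'GPU' when the onnxruntime-gpu wheel is installed, and the function object then prefers CUDA unless the caller opts out. A hedged usage sketch (the op name is simply the one exercised by the tests below):

    from onnxruntime_extensions import OrtPyFunction

    fn_auto = OrtPyFunction.from_customop("StftNorm")                # CUDA if onnxruntime-gpu is present
    fn_cpu = OrtPyFunction.from_customop("StftNorm", cpu_only=True)  # always CPUExecutionProvider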
@@ -98,20 +113,29 @@ class OrtPyFunction:
         self._oxml = oxml
         if model_path is not None:
             self.ort_session = _ort.InferenceSession(
-                model_path, self.get_ort_session_options())
+                model_path, self.get_ort_session_options(),
+                self.execution_providers)
         return self

     def _ensure_ort_session(self):
         if self.ort_session is None:
             sess = _ort.InferenceSession(
-                self.onnx_model.SerializeToString(), self.get_ort_session_options())
+                self.onnx_model.SerializeToString(), self.get_ort_session_options(),
+                self.execution_providers)
             self.ort_session = sess

         return self.ort_session

+    @staticmethod
+    def _get_kwarg_device(kwargs):
+        cpuonly = kwargs.get('cpu_only', None)
+        if cpuonly is not None:
+            del kwargs['cpu_only']
+        return cpuonly
+
     @classmethod
     def from_customop(cls, op_type, *args, **kwargs):
-        return cls().create_from_customop(op_type, *args, **kwargs)
+        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)

     @classmethod
     def from_model(cls, path_or_model, *args, **kwargs):
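Passing the provider list explicitly is not optional on newer runtimes: since onnxruntime 1.10, InferenceSession raises an error on a GPU build when no providers argument is given. A minimal sketch of the call pattern (the model path is hypothetical):

    import onnxruntime as ort

    sess = ort.InferenceSession(
        "model.onnx",
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])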
@@ -121,7 +145,7 @@ class OrtPyFunction:
             mpath = path_or_model
         else:
             oxml = path_or_model
-        return cls()._bind(oxml, mpath)
+        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)

     def _argument_map(self, *args, **kwargs):
         idx = 0
@@ -23,14 +23,18 @@ class PyCustomOpDef:
     dt_bfloat16: int = ...
     ...


 def enable_py_op(enabled: bool) -> bool:
     ...


 def add_custom_op(opdef: PyCustomOpDef) -> None:
     ...


 def hash_64(s: str, num_buckets: int, fast: int) -> int:
     ...


 def default_opset_domain() -> str:
     ...
@@ -1,7 +1,8 @@
 import json
 import pathlib
 from ._onnx_ops import make_model_ex
-from .._ortapi2 import SingleOpGraph, default_opset_domain, GPT2Tokenizer, VectorToString
+from .._cuops import SingleOpGraph, GPT2Tokenizer, VectorToString
+from .._ortapi2 import default_opset_domain


 def is_path(name_or_buffer):
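This import split follows from the _ortapi2 change above: once _ortapi2 stops re-exporting everything from _cuops through a star import, graph-building names such as SingleOpGraph, GPT2Tokenizer, and VectorToString have to be pulled from _cuops directly, and only default_opset_domain still comes from _ortapi2.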
@@ -5,6 +5,7 @@ import inspect

 import numpy as np


 # some util function for testing and tools
 def get_test_data_file(*sub_dirs):
     case_file = inspect.currentframe().f_back.f_code.co_filename
@@ -18,7 +19,7 @@ def read_file(path, mode='r'):


 def mel_filterbank(
-        n_fft:int, n_mels: int=80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
+        n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
     """
     Compute a Mel-filterbank. The filters are stored in the rows, the columns
     and it is Slaney normalized mel-scale filterbank.
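The oddly specific default max_mel=45.245640471924965 is just the Nyquist frequency of the default sr=16000 expressed on the Slaney mel scale, which is linear below 1000 Hz and logarithmic above it. A quick check, assuming the Slaney constants used by librosa-style filterbanks:

    import math

    def hz_to_slaney_mel(f_hz: float) -> float:
        # Linear region: 3 mels per 200 Hz, so mel(1000 Hz) == 15.
        if f_hz < 1000.0:
            return 3.0 * f_hz / 200.0
        # Log region: 27 mels per factor of 6.4 in frequency.
        return 15.0 + 27.0 * math.log(f_hz / 1000.0) / math.log(6.4)

    print(hz_to_slaney_mel(8000.0))  # ~45.245640471924965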
@@ -1,5 +1,7 @@
 # include requirements.txt so pip has context to avoid installing incompatible dependencies
 -r requirements.txt
 pytest
+# multiple versions of onnxruntime are supported, but only one can be installed at a time
+onnxruntime >=1.10.0
 transformers >= 4.9.2,<=4.24.0
 tensorflow_text >=2.5.0
@@ -1,2 +1 @@
 onnx>=1.9.0
-onnxruntime>=1.10.0
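Dropping onnxruntime from the runtime requirements is presumably the core of the fix: with the pin gone, pip no longer pulls the CPU onnxruntime wheel onto machines that already have onnxruntime-gpu, and the import-time version gate in _ortapi2 takes over enforcing the >= 1.10.0 floor for whichever flavor is installed.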
@@ -99,7 +99,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))

-        ortx_stft = PyOrtFunction.from_model(_create_test_model())
+        ortx_stft = PyOrtFunction.from_model(_create_test_model(), cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
@@ -109,7 +109,7 @@ class TestAudio(unittest.TestCase):
         audio_pcm = self.test_pcm
         expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))

-        ortx_stft = PyOrtFunction.from_customop("StftNorm")
+        ortx_stft = PyOrtFunction.from_customop("StftNorm", cpu_only=True)
         actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
         actual = actual[0]
         np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
@@ -135,7 +135,7 @@ def _torch_export(*arg, **kwargs):
     with io.BytesIO() as f:
         torch.onnx.export(*arg, f, **kwargs)
         return onnx.load_from_string(f.getvalue())


 def preprocessing(audio_data):
     if USE_AUDIO_DECODER:
@@ -179,7 +179,7 @@ def preprocessing(audio_data):
     return result


-def merge_models(core: str, output_model:str, audio_data):
+def merge_models(core: str, output_model: str, audio_data):
     m_pre = onnx.load_model("whisper_codec_pre.onnx" if USE_AUDIO_DECODER else "whisper_pre.onnx")
     m_core = onnx.load_model(core)
     m1 = onnx.compose.merge_models(m_pre, m_core, io_map=[("log_mel", "input_features")])
@@ -236,7 +236,7 @@ if __name__ == '__main__':
     print(f"{onnx_model_name}")

     model_name = "openai/" + onnx_model_name[:-len("_beamsearch.onnx")]

     _processor = WhisperProcessor.from_pretrained(model_name)
     # The model similar to Huggingface model like:
     # model = WhisperForConditionalGeneration.from_pretrained(model_name)
@@ -253,7 +253,7 @@ if __name__ == '__main__':
         audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
     else:
         audio_blob, _ = librosa.load(test_file)
-    audio_blob = np.expand_dims(audio_blob, axis=0) # add a batch_size dimension
+    audio_blob = np.expand_dims(audio_blob, axis=0)  # add a batch_size dimension

     log_mel = preprocessing(audio_blob)
     print(log_mel.shape)