Add a new API for building data processing graph from Huggingface transformers processor/tokenizer (#482)
* initial checkins * test pass * basic impl * first unit test pass * merge error * refine a little bit * add more unit test * fix unit test * Fix the unit test. * add one more whisper audiodecoder test case * update the docs * More updates
This commit is contained in:
Родитель
c5e7472070
Коммит
981cb049ff
47
README.md
47
README.md
|
@ -15,7 +15,7 @@ pip install onnxruntime-extensions
|
|||
````
|
||||
|
||||
|
||||
### **nightly build**
|
||||
### **Nightly Build**
|
||||
|
||||
#### <strong>on Windows</strong>
|
||||
```cmd
|
||||
|
@ -23,7 +23,7 @@ pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_pa
|
|||
```
|
||||
Please ensure that you have met the prerequisites of onnxruntime-extensions (e.g., onnx and onnxruntime) in your Python environment.
|
||||
#### <strong>on Linux/macOS</strong>
|
||||
the packages are not ready yet, so it could be installed from source. Please make sure the compiler toolkit like gcc(later than g++ 8.0) or clang, and the tool <strong>cmake</strong> are installed before the following command
|
||||
Please make sure the compiler toolkit like gcc(later than g++ 8.0) or clang are installed before the following command
|
||||
```bash
|
||||
python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.git
|
||||
```
|
||||
|
@ -31,12 +31,16 @@ python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.gi
|
|||
|
||||
## Usage
|
||||
|
||||
## 1. Augment an ONNX model with a pre- and post-processing pipeline
|
||||
check [tutorial](./tutorials) for a couple of examples on how to do it.
|
||||
## 1. Generate the pre-/post- processing ONNX model
|
||||
With onnxruntime-extensions Python package, you can easily get the ONNX processing graph by converting them from Huggingface transformer data processing classes, check the following API for details.
|
||||
```python
|
||||
help(onnxruntime_extensions.gen_processing_models)
|
||||
```
|
||||
### NOTE: These data processing model can be merged into other model [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed.
|
||||
## 2. Using Extensions for ONNX Runtime inference
|
||||
|
||||
### Python
|
||||
|
||||
There are individual packages for the following languages, please install it for the build.
|
||||
```python
|
||||
import onnxruntime as _ort
|
||||
from onnxruntime_extensions import get_library_path as _lib_path
|
||||
|
@ -67,34 +71,13 @@ var sess_opt = new OrtSession.SessionOptions();
|
|||
sess_opt.registerCustomOpLibrary(OrtxPackage.getLibraryPath());
|
||||
```
|
||||
|
||||
## Use exporters to generate graphs with custom operators
|
||||
|
||||
The PyTorch and TensorFlow converters support custom operator generation if the operation from the original framework cannot be interpreted as a standard ONNX operators. Check the following two examples on how to do this.
|
||||
|
||||
1. [CustomOp conversion by pytorch.onnx.exporter](https://github.com/microsoft/onnxruntime-extensions/blob/main/tutorials/pytorch_custom_ops_tutorial.ipynb)
|
||||
2. [CustomOp conversion by tf2onnx](https://github.com/microsoft/onnxruntime-extensions/blob/main/tutorials/tf2onnx_custom_ops_tutorial.ipynb)
|
||||
|
||||
|
||||
## Add a new custom operator to onnxruntime-extensions
|
||||
|
||||
You can contribute customop C++ implementations directly in this repository if they have general applicability to other users. In addition, if you want to quickly verify the ONNX model with Python, you can wrap the custom operator with **[PyOp](docs/pyop.md)**.
|
||||
|
||||
```python
|
||||
import numpy
|
||||
from onnxruntime_extensions import PyOp, onnx_op
|
||||
|
||||
# Implement the CustomOp by decorating a function with onnx_op
|
||||
@onnx_op(op_type="Inverse", inputs=[PyOp.dt_float])
|
||||
def inverse(x):
|
||||
# the user custom op implementation here:
|
||||
return numpy.linalg.inv(x)
|
||||
|
||||
# Run the model with this custom op
|
||||
# model_func = PyOrtFunction(model_path)
|
||||
# outputs = model_func(inputs)
|
||||
# ...
|
||||
### C#
|
||||
```C#
|
||||
SessionOptions options = new SessionOptions()
|
||||
options.RegisterOrtExtensions()
|
||||
session = new InferenceSession(model, options)
|
||||
```
|
||||
Check [development.md](./docs/development.md) for build and test
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ else()
|
|||
message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}")
|
||||
|
||||
# 1.15.1 is the latest ORT release.
|
||||
set(ONNXRUNTIME_VER "1.15.1" CACHE STRING "ONNX Runtime version")
|
||||
set(ONNXRUNTIME_VER "1.15.1")
|
||||
|
||||
if(APPLE)
|
||||
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz")
|
||||
|
|
|
@ -3,16 +3,36 @@
|
|||
Before implement a custom operator, you get the ONNX model with one or more ORT custom operators, created by ONNX converters, [ONNX-Script](https://github.com/microsoft/onnx-script), or [ONNX model API](https://onnx.ai/onnx/api/helper.html) and etc..
|
||||
|
||||
|
||||
## 1. Generate the C++ template code of the Custom operator from the ONNX Model (optional)
|
||||
## 1. Quick verification with PythonOp (optional)
|
||||
|
||||
Before you actually develop a custom operator for the work, if you want to quickly verify the ONNX model with Python, you can wrap the custom operator with **[PyOp](docs/pyop.md)**.
|
||||
|
||||
```python
|
||||
import numpy
|
||||
from onnxruntime_extensions import PyOp, onnx_op
|
||||
|
||||
# Implement the CustomOp by decorating a function with onnx_op
|
||||
@onnx_op(op_type="Inverse", inputs=[PyOp.dt_float])
|
||||
def inverse(x):
|
||||
# the user custom op implementation here:
|
||||
return numpy.linalg.inv(x)
|
||||
|
||||
# Run the model with this custom op
|
||||
# model_func = PyOrtFunction(model_path)
|
||||
# outputs = model_func(inputs)
|
||||
# ...
|
||||
```
|
||||
|
||||
## 2. Generate the C++ template code of the Custom operator from the ONNX Model (optional)
|
||||
python -m onnxruntime-extensions.cmd --cpp-gen <model_path> <repository_dir>`
|
||||
If you are familiar with the ONNX model detail, you create the custom operator C++ classes directly.
|
||||
|
||||
|
||||
## 2. Implement the CustomOp Kernel Compute method in the generated C++ files.
|
||||
the custom operator kernel C++ code exmaple can be found [operators](../operators/) folder, like [KernelGaussianBlur](../operators/cv2/gaussian_blur.hpp). All C++ APIs that can be used in the kernel implmentation are listed below
|
||||
## 3. Implement the CustomOp Kernel Compute method in the generated C++ files.
|
||||
the custom operator kernel C++ code example can be found [operators](../operators/) folder, like [gaussian_blur](../operators/cv2/imgproc/gaussian_blur.hpp). All C++ APIs that can be used in the kernel implementation are listed below
|
||||
|
||||
* [ONNXRuntime Custom API docs](https://onnxruntime.ai/docs/api/c/struct_ort_custom_op.html)
|
||||
* the third libraries API docs intergrated in ONNXRuntime Extensions the can be used in C++ code
|
||||
* the third libraries API docs integrated in ONNXRuntime Extensions the can be used in C++ code
|
||||
- OpenCV API docs https://docs.opencv.org/4.x/
|
||||
- Google SentencePiece Library docs https://github.com/google/sentencepiece/blob/master/doc/api.md
|
||||
- dlib(matrix and ML library) C++ API docs http://dlib.net/algorithms.html
|
||||
|
@ -22,6 +42,6 @@ the custom operator kernel C++ code exmaple can be found [operators](../operator
|
|||
|
||||
## 3. Build and Test
|
||||
- The unit tests can be implemented as Python or C++, check [test](../test) folder for more examples
|
||||
- Check [build-package](./development.md) on how to build the different langauage package to be used for production.
|
||||
- Check [build-package](./development.md) on how to build the different language package to be used for production.
|
||||
|
||||
Please check the [contribution](../README.md#contributing) to see if it is possible to contribute the custom operator to onnxruntime-extensions.
|
||||
|
|
|
@ -21,7 +21,7 @@ from ._ocos import default_opset_domain # noqa
|
|||
from ._cuops import * # noqa
|
||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
|
||||
from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError
|
||||
|
||||
from .cvt import gen_processing_models
|
||||
|
||||
onnx_op = Opdef.declare
|
||||
PyOp = PyCustomOpDef
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
_cuops.py: Custom operators signatures for Python usage.
|
||||
"""
|
||||
|
||||
import onnx
|
||||
import numpy
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
@ -93,12 +97,12 @@ class BpeDecoder(CustomOp):
|
|||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [
|
||||
cls.io_def("ids", onnx.TensorProto.INT64, [])
|
||||
cls.io_def("ids", onnx.TensorProto.INT64, None)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def('str', onnx_proto.TensorProto.STRING, [])]
|
||||
return [cls.io_def('str', onnx_proto.TensorProto.STRING, None)]
|
||||
|
||||
|
||||
class VectorToString(CustomOp):
|
||||
|
@ -423,39 +427,6 @@ class StftNorm(CustomOp):
|
|||
]
|
||||
|
||||
|
||||
class SingleOpGraph:
|
||||
|
||||
@classmethod
|
||||
def get_next_id(cls):
|
||||
if not hasattr(cls, '_id_counter'):
|
||||
cls._id_counter = 0
|
||||
cls._id_counter += 1
|
||||
return cls._id_counter
|
||||
|
||||
@classmethod
|
||||
def build_my_graph(cls, op_class, *args, **kwargs):
|
||||
if isinstance(op_class, str):
|
||||
op_class = cls.get_op_class(op_class)
|
||||
|
||||
op_type = op_class.op_type()
|
||||
inputs = op_class.get_inputs()
|
||||
outputs = op_class.get_outputs()
|
||||
attrs = op_class.serialize_attr(kwargs)
|
||||
cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
|
||||
[o_.name for o_ in outputs],
|
||||
"{}_{}".format(op_type,
|
||||
cls.get_next_id()),
|
||||
**attrs,
|
||||
domain=default_opset_domain())
|
||||
graph = onnx.helper.make_graph([cuop], "og_{}_{}".format(
|
||||
op_type, cls.get_next_id()), inputs, outputs)
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def get_op_class(op_type):
|
||||
return globals()[op_type]
|
||||
|
||||
|
||||
# TODO: have a C++ impl.
|
||||
def _argsort_op(x, dim):
|
||||
d = numpy.argsort(x, dim)
|
||||
|
@ -470,3 +441,43 @@ Opdef.create(_argsort_op,
|
|||
|
||||
class CustomOpConverter:
|
||||
pass
|
||||
|
||||
|
||||
class SingleOpGraph:
|
||||
|
||||
@classmethod
|
||||
def get_next_id(cls):
|
||||
if not hasattr(cls, '_id_counter'):
|
||||
cls._id_counter = 0
|
||||
cls._id_counter += 1
|
||||
return cls._id_counter
|
||||
|
||||
@classmethod
|
||||
def build_graph(cls, op_class, *args, **kwargs):
|
||||
if isinstance(op_class, str):
|
||||
op_class = cls.get_op_class(op_class)
|
||||
|
||||
cvt = kwargs.pop('cvt', None)
|
||||
if cvt is None and len(args) > 0 and isinstance(args[0], CustomOpConverter):
|
||||
cvt = args[0]
|
||||
args = args[1:]
|
||||
|
||||
new_kwargs = kwargs if cvt is None else cvt(**kwargs)
|
||||
|
||||
op_type = op_class.op_type()
|
||||
inputs = op_class.get_inputs()
|
||||
outputs = op_class.get_outputs()
|
||||
attrs = op_class.serialize_attr(new_kwargs)
|
||||
cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
|
||||
[o_.name for o_ in outputs],
|
||||
"{}_{}".format(op_type,
|
||||
cls.get_next_id()),
|
||||
**attrs,
|
||||
domain=default_opset_domain())
|
||||
graph = onnx.helper.make_graph([cuop], "og_{}_{}".format(
|
||||
op_type, cls.get_next_id()), inputs, outputs)
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def get_op_class(op_type):
|
||||
return globals()[op_type]
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
_hf_cvt.py: HuggingFace Tokenizer/Processor Converter
|
||||
"""
|
||||
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
from ._cuops import CustomOpConverter, SingleOpGraph
|
||||
from .util import read_file
|
||||
|
||||
|
||||
class HFTokenizerConverter(CustomOpConverter):
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def bpe_tokenizer(self, **kwargs):
|
||||
hf_gpt2_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def bpe_decoder(self, **kwargs):
|
||||
decoder = self.tokenizer.decoder
|
||||
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
|
||||
# with open("id_vocab.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(id_vocab)
|
||||
byte_decoder = self.tokenizer.byte_decoder
|
||||
str_byte_decoder = "\n".join(["{}\t{}".format(
|
||||
ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
|
||||
# with open("byte_decoder.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(str_byte_decoder)
|
||||
all_special_ids = self.tokenizer.all_special_ids
|
||||
added_tokens = self.tokenizer.added_tokens_decoder
|
||||
str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
|
||||
str_added_tokens = "\n".join(
|
||||
["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
|
||||
kwargs.update({
|
||||
"id_vocab": id_vocab,
|
||||
"byte_decoder": str_byte_decoder,
|
||||
"added_tokens": str_added_tokens,
|
||||
"all_special_ids": str_all_special_ids,
|
||||
"skip_special_tokens": kwargs.get("skip_special_tokens", False)
|
||||
})
|
||||
return kwargs
|
||||
|
||||
def clip_tokenizer(self, **kwargs):
|
||||
hf_clip_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_clip_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_clip_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def roberta_tokenizer(self, **kwargs):
|
||||
hf_roberta_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_roberta_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_roberta_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def t5_tokenizer(self, **kwargs):
|
||||
attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def t5_decoder(self, **kwargs):
|
||||
attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
|
||||
_PROCESSOR_DICT = {
|
||||
"GPT2Tokenizer": ('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
|
||||
'BpeDecoder', HFTokenizerConverter.bpe_decoder),
|
||||
"ClipTokenizer": ('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
|
||||
'BpeDecoder', HFTokenizerConverter.bpe_decoder),
|
||||
"RobertaTokenizer": ("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
|
||||
None, None),
|
||||
"T5Tokenizer": ("SentencepieceTokenizer", HFTokenizerConverter.t5_tokenizer,
|
||||
"SentencepieceDecoder", HFTokenizerConverter.t5_decoder),
|
||||
}
|
||||
|
||||
|
||||
class HFTokenizerOnnxGraph:
|
||||
@staticmethod
|
||||
def extract_cls_name(processor):
|
||||
cls_name = processor if isinstance(processor, str) else type(processor).__name__
|
||||
if cls_name.endswith("TokenizerFast"):
|
||||
cls_name = cls_name[:-len("Fast")]
|
||||
return cls_name
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, processor):
|
||||
cls_name = cls.extract_cls_name(processor)
|
||||
return cls_name in _PROCESSOR_DICT
|
||||
|
||||
def __init__(self, processor, **kwargs):
|
||||
cls_name = self.extract_cls_name(processor)
|
||||
self.cvt_quadruple = _PROCESSOR_DICT[cls_name]
|
||||
self.cvt_obj = HFTokenizerConverter(processor)
|
||||
|
||||
def pre_processing(self, **kwargs):
|
||||
_cvt_op = self.cvt_quadruple[0]
|
||||
_cvt_func = self.cvt_quadruple[1]
|
||||
cvt = partial(_cvt_func, self.cvt_obj)
|
||||
return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
|
||||
|
||||
def post_processing(self, **kwargs):
|
||||
_cvt_op = self.cvt_quadruple[2]
|
||||
_cvt_func = self.cvt_quadruple[3]
|
||||
cvt = partial(_cvt_func, self.cvt_obj)
|
||||
return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
|
|
@ -2,6 +2,9 @@
|
|||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
###############################################################################
|
||||
"""
|
||||
_ocos.py: PythonOp implementation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import copy
|
||||
|
|
|
@ -3,9 +3,14 @@
|
|||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
_ortapi2.py: ONNXRuntime-Extensions Python API
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from ._ocos import default_opset_domain, get_library_path # noqa
|
||||
from ._cuops import onnx, onnx_proto, CustomOpConverter, SingleOpGraph
|
||||
from ._cuops import onnx, onnx_proto, SingleOpGraph
|
||||
|
||||
|
||||
_ort_check_passed = False
|
||||
try:
|
||||
|
@ -54,12 +59,10 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
|
|||
|
||||
|
||||
class OrtPyFunction:
|
||||
|
||||
@classmethod
|
||||
def get_ort_session_options(cls):
|
||||
# ONNXRuntime has an issue to support reusing the SessionOptions object.
|
||||
# Create a new one every time here
|
||||
def get_ort_session_options(self):
|
||||
so = _ort.SessionOptions()
|
||||
for k, v in self.extra_session_options.items():
|
||||
so.__setattr__(k, v)
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
return so
|
||||
|
||||
|
@ -71,18 +74,10 @@ class OrtPyFunction:
|
|||
if not cpu_only:
|
||||
if _ort.get_device() == 'GPU':
|
||||
self.execution_providers = ['CUDAExecutionProvider']
|
||||
self.extra_session_options = {}
|
||||
|
||||
def create_from_customop(self, op_type, *args, **kwargs):
|
||||
cvt = kwargs.get('cvt', None)
|
||||
if cvt is None:
|
||||
cvt = args[0] if len(args) > 0 and isinstance(
|
||||
args[0], CustomOpConverter) else None
|
||||
args = args[1:]
|
||||
else:
|
||||
del kwargs['cvt']
|
||||
|
||||
new_kwargs = kwargs if cvt is None else cvt(**kwargs)
|
||||
graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
|
||||
graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
|
||||
self._bind(make_onnx_model(graph))
|
||||
return self
|
||||
|
||||
|
|
|
@ -0,0 +1,233 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
_torch_cvt.py: Data processing graph converted from PyTorch
|
||||
"""
|
||||
|
||||
import io
|
||||
import onnx
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from onnx import numpy_helper
|
||||
|
||||
from ._ortapi2 import make_onnx_model
|
||||
from ._cuops import SingleOpGraph
|
||||
from ._hf_cvt import HFTokenizerConverter
|
||||
from .util import remove_unused_initializers
|
||||
|
||||
|
||||
class _WhisperHParams:
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = N_SAMPLES // HOP_LENGTH
|
||||
|
||||
|
||||
def _mel_filterbank(
|
||||
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
|
||||
"""
|
||||
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
|
||||
and it is Slaney normalized mel-scale filterbank.
|
||||
"""
|
||||
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
|
||||
|
||||
# the centers of the frequency bins for the DFT
|
||||
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
mel = np.linspace(min_mel, max_mel, n_mels + 2)
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mel
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
log_t = mel >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
|
||||
mel_bins = freqs
|
||||
|
||||
mel_spacing = np.diff(mel_bins)
|
||||
|
||||
ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
|
||||
for i in range(n_mels):
|
||||
left = -ramps[i] / mel_spacing[i]
|
||||
right = ramps[i + 2] / mel_spacing[i + 1]
|
||||
|
||||
# intersect them with each other and zero
|
||||
fbank[i] = np.maximum(0, np.minimum(left, right))
|
||||
|
||||
energy_norm = 2.0 / (mel_bins[2 : n_mels + 2] - mel_bins[:n_mels])
|
||||
fbank *= energy_norm[:, np.newaxis]
|
||||
return fbank
|
||||
|
||||
|
||||
class CustomOpStftNorm(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, self, n_fft, hop_length, window):
|
||||
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
|
||||
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, audio, n_fft, hop_length, window):
|
||||
win_length = window.shape[0]
|
||||
stft = torch.stft(audio, n_fft, hop_length, win_length, window,
|
||||
center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
||||
return stft.abs() ** 2
|
||||
|
||||
|
||||
class WhisperPrePipeline(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.window = torch.hann_window(_WhisperHParams.N_FFT)
|
||||
self.mel_filters = torch.from_numpy(
|
||||
_mel_filterbank(
|
||||
sr=_WhisperHParams.SAMPLE_RATE,
|
||||
n_fft=_WhisperHParams.N_FFT,
|
||||
n_mels=_WhisperHParams.N_MELS))
|
||||
|
||||
def forward(self, audio_pcm: torch.Tensor):
|
||||
stft_norm = CustomOpStftNorm.apply(audio_pcm,
|
||||
_WhisperHParams.N_FFT,
|
||||
_WhisperHParams.HOP_LENGTH,
|
||||
self.window)
|
||||
magnitudes = stft_norm[:, :, :-1]
|
||||
mel_spec = self.mel_filters @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
spec_min = log_spec.max() - 8.0
|
||||
log_spec = torch.maximum(log_spec, spec_min)
|
||||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1], (
|
||||
_WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH -
|
||||
spec_shape[2]), dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
|
||||
def _to_onnx_stft(onnx_model):
|
||||
"""Convert custom-op STFT-Norm to ONNX STFT"""
|
||||
node_idx = 0
|
||||
new_stft_nodes = []
|
||||
stft_norm_node = None
|
||||
for node in onnx_model.graph.node:
|
||||
if node.op_type == "StftNorm":
|
||||
stft_norm_node = node
|
||||
break
|
||||
node_idx += 1
|
||||
|
||||
if stft_norm_node is None:
|
||||
raise RuntimeError("Cannot find STFTNorm node in the graph")
|
||||
|
||||
make_node = onnx.helper.make_node
|
||||
replaced_nodes = [
|
||||
make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
|
||||
value=numpy_helper.from_array(np.array([0,
|
||||
_WhisperHParams.N_FFT // 2, 0,
|
||||
_WhisperHParams.N_FFT // 2], dtype='int64'),
|
||||
name='const_14')),
|
||||
make_node('Pad',
|
||||
inputs=[stft_norm_node.input[0], 'const_14_output_0'],
|
||||
outputs=['pad_1_output_0'], mode='reflect'),
|
||||
make_node('STFT',
|
||||
inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
|
||||
outputs=['stft_output_0'], name='stft', domain='', onesided=1),
|
||||
make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
|
||||
perm=[0, 2, 1, 3]),
|
||||
make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
|
||||
value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
|
||||
value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
|
||||
value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
|
||||
make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
|
||||
value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
|
||||
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
|
||||
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
|
||||
name='slice_1'),
|
||||
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
|
||||
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
|
||||
name='gather_4', axis=3),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
|
||||
name='gather_5', axis=3),
|
||||
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
|
||||
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
|
||||
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
|
||||
]
|
||||
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
|
||||
new_stft_nodes.extend(replaced_nodes)
|
||||
new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
|
||||
del onnx_model.graph.node[:]
|
||||
onnx_model.graph.node.extend(new_stft_nodes)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def _torch_export(*arg, **kwargs):
|
||||
with io.BytesIO() as f:
|
||||
torch.onnx.export(*arg, f, **kwargs)
|
||||
return onnx.load_from_string(f.getvalue())
|
||||
|
||||
|
||||
class WhisperDataProcGraph:
|
||||
def __init__(self, processor, **kwargs):
|
||||
self.hf_processor = processor
|
||||
_opset = kwargs.pop('opset', 17)
|
||||
self.opset_version = _opset if _opset else 17
|
||||
|
||||
def pre_processing(self, **kwargs):
|
||||
use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
|
||||
use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
|
||||
whisper_processing = WhisperPrePipeline()
|
||||
|
||||
audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
|
||||
model_args = (audio_pcm,)
|
||||
pre_model = _torch_export(
|
||||
whisper_processing,
|
||||
model_args,
|
||||
input_names=["audio_pcm"],
|
||||
output_names=["log_mel"],
|
||||
do_constant_folding=True,
|
||||
export_params=True,
|
||||
opset_version=self.opset_version,
|
||||
dynamic_axes={
|
||||
"audio_pcm": {1: "sample_len"},
|
||||
}
|
||||
)
|
||||
if use_onnx_stft:
|
||||
pre_model = _to_onnx_stft(pre_model)
|
||||
remove_unused_initializers(pre_model.graph)
|
||||
|
||||
pre_full = pre_model
|
||||
if use_audio_decoder:
|
||||
audecoder_g = SingleOpGraph.build_graph(
|
||||
"AudioDecoder", downsampling_rate=_WhisperHParams.SAMPLE_RATE, stereo_to_mono=1)
|
||||
audecoder_m = make_onnx_model(audecoder_g)
|
||||
pre_full = onnx.compose.merge_models(
|
||||
audecoder_m,
|
||||
pre_model,
|
||||
io_map=[("floatPCM", "audio_pcm")])
|
||||
|
||||
return pre_full
|
||||
|
||||
def post_processing(self, **kwargs):
|
||||
g = SingleOpGraph.build_graph(
|
||||
"BpeDecoder",
|
||||
cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
|
||||
skip_special_tokens=True,
|
||||
cpu_only=True)
|
||||
return make_onnx_model(g)
|
|
@ -1,3 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
cmd.py: cli commands for onnxruntime_extensions
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import onnx
|
||||
|
|
|
@ -1,65 +1,71 @@
|
|||
import json
|
||||
from ._cuops import CustomOpConverter
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
cvt.py: Processing Graph Converter and Generator
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph # noqa
|
||||
from ._ortapi2 import make_onnx_model
|
||||
|
||||
|
||||
class HFTokenizerConverter(CustomOpConverter):
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
_is_torch_available = False
|
||||
try:
|
||||
import torch # noqa
|
||||
_is_torch_available = True
|
||||
from ._torch_cvt import WhisperDataProcGraph
|
||||
except ImportError:
|
||||
WhisperDataProcGraph = None
|
||||
|
||||
def bpe_tokenizer(self, **kwargs):
|
||||
hf_gpt2_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def bpe_decoder(self, **kwargs):
|
||||
decoder = self.tokenizer.decoder
|
||||
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
|
||||
# with open("id_vocab.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(id_vocab)
|
||||
byte_decoder = self.tokenizer.byte_decoder
|
||||
str_byte_decoder = "\n".join(["{}\t{}".format(
|
||||
ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
|
||||
# with open("byte_decoder.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(str_byte_decoder)
|
||||
all_special_ids = self.tokenizer.all_special_ids
|
||||
added_tokens = self.tokenizer.added_tokens_decoder
|
||||
str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
|
||||
str_added_tokens = "\n".join(
|
||||
["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
|
||||
kwargs.update({
|
||||
"id_vocab": id_vocab,
|
||||
"byte_decoder": str_byte_decoder,
|
||||
"added_tokens": str_added_tokens,
|
||||
"all_special_ids": str_all_special_ids,
|
||||
"skip_special_tokens": kwargs.get("skip_special_tokens", False)
|
||||
})
|
||||
def gen_processing_models(processor: Union[str, object],
|
||||
pre_kwargs: dict = None,
|
||||
post_kwargs: dict = None,
|
||||
opset: int = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Generate the pre- and post-processing ONNX model, basing on the name or HF class.
|
||||
|
||||
return kwargs
|
||||
Parameters
|
||||
----------
|
||||
processor:
|
||||
the HF processor/tokenizer instance, or the name (str) of a Data Processor
|
||||
the instance is preferred, otherwise when name was given, the corresponding configuration for the processor
|
||||
has to be provided in the kwargs
|
||||
pre_kwargs: dict
|
||||
Keyword arguments for generating the pre-processing model
|
||||
post_kwargs: dict
|
||||
Keyword arguments for generating the post-processing model
|
||||
opset: int
|
||||
the target opset version of the model
|
||||
kwargs:
|
||||
The additional arguments for generating models
|
||||
|
||||
def clip_tokenizer(self, **kwargs):
|
||||
hf_clip_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_clip_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_clip_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
Returns
|
||||
-------
|
||||
ONNX-Models
|
||||
The pre- and post-processing ONNX models
|
||||
"""
|
||||
if pre_kwargs is None and post_kwargs is None:
|
||||
raise ValueError("Either pre_kwargs or post_kwargs should be provided. None means no processing")
|
||||
|
||||
def roberta_tokenizer(self, **kwargs):
|
||||
hf_roberta_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_roberta_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_roberta_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
cls_name = processor if isinstance(processor, str) else type(processor).__name__
|
||||
if cls_name == "WhisperProcessor":
|
||||
if WhisperDataProcGraph is None:
|
||||
raise ValueError("The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
|
||||
_converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
|
||||
pre_m = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
|
||||
post_m = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
|
||||
return pre_m, post_m
|
||||
elif HFTokenizerOnnxGraph.is_supported(processor):
|
||||
_converter = HFTokenizerOnnxGraph(processor)
|
||||
pre_g = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
|
||||
post_g = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
|
||||
return make_onnx_model(pre_g) if pre_g else None, \
|
||||
make_onnx_model(post_g) if post_g else None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
util.py: Miscellaneous utility functions
|
||||
"""
|
||||
|
||||
import onnx
|
||||
import pathlib
|
||||
import inspect
|
||||
|
@ -22,7 +27,7 @@ def read_file(path, mode='r'):
|
|||
def mel_filterbank(
|
||||
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
|
||||
"""
|
||||
Compute a Mel-filterbank. The filters are stored in the rows, the columns
|
||||
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
|
||||
and it is Slaney normalized mel-scale filterbank.
|
||||
"""
|
||||
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
|
||||
|
|
|
@ -146,8 +146,8 @@ struct AudioDecoder : public BaseKernel {
|
|||
}
|
||||
|
||||
if (downsample_rate_ != 0 &&
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: only down sampling supported.", ORT_INVALID_ARGUMENT);
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
// join all frames
|
||||
|
@ -172,7 +172,7 @@ struct AudioDecoder : public BaseKernel {
|
|||
if (downsample_rate_ != 0 &&
|
||||
downsample_rate_ != orig_sample_rate) {
|
||||
// A lowpass filter on buf audio data to remove high frequency noise
|
||||
ButterworthLowpass filter(1.0 * orig_sample_rate, 0.5 * downsample_rate_);
|
||||
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
|
||||
std::vector<float> filtered_buf = filter.Process(buf);
|
||||
// downsample the audio data
|
||||
KaiserWindowInterpolation::Process(filtered_buf, buf,
|
||||
|
|
1
setup.py
1
setup.py
|
@ -87,7 +87,6 @@ class BuildCMakeExt(_build_ext):
|
|||
if os.environ.get('OCOS_NO_OPENCV') == '1':
|
||||
# Disabling openCV can drastically reduce the build time.
|
||||
cmake_args += [
|
||||
'-DOCOS_ENABLE_CTEST=OFF',
|
||||
'-DOCOS_ENABLE_OPENCV_CODECS=OFF',
|
||||
'-DOCOS_ENABLE_CV2=OFF',
|
||||
'-DOCOS_ENABLE_VISION=OFF']
|
||||
|
|
|
@ -200,8 +200,9 @@ void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::v
|
|||
OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
||||
output.resize(len);
|
||||
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
|
||||
if (i < len - 1)
|
||||
if (i < static_cast<int64_t>(len) - 1) {
|
||||
result[offsets[i + (int64_t)1]] = '\0';
|
||||
}
|
||||
output[i] = result.data() + offsets[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#ifdef ENABLE_CV2
|
||||
#include "gtest/gtest.h"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
|
||||
|
@ -65,3 +67,5 @@ TEST(VisionOps, image_decode_encode) {
|
|||
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import transformers as _hfts
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as _ort
|
||||
from packaging import version
|
||||
from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
|
||||
|
||||
|
||||
class TestAutoTokenizer(unittest.TestCase):
|
||||
def test_t5_tokenizer(self):
|
||||
tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
|
||||
ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
|
||||
print(ids)
|
||||
|
||||
alpha = 0
|
||||
nbest_size = 0
|
||||
flags = 0
|
||||
|
||||
t5_default_inputs = (
|
||||
np.array(
|
||||
[nbest_size], dtype=np.int64),
|
||||
np.array([alpha], dtype=np.float32),
|
||||
np.array([flags & 1], dtype=np.bool_),
|
||||
np.array([flags & 2], dtype=np.bool_),
|
||||
np.array([flags & 4], dtype=np.bool_))
|
||||
|
||||
ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
|
||||
actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
|
||||
np.testing.assert_array_equal(ids[0][:-1], actual_ids)
|
||||
|
||||
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
|
||||
def test_whisper(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, post_m = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
|
||||
post_kwargs={})
|
||||
|
||||
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
|
||||
t = np.linspace(0, 2*np.pi, 480000).astype(np.float32)
|
||||
simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0)
|
||||
log_mel = fn_pre(simaudio)
|
||||
|
||||
self.assertEqual(log_mel.shape, (1, 80, 3000))
|
||||
|
||||
fn_post = OrtPyFunction.from_model(post_m)
|
||||
rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32))
|
||||
self.assertEqual(rel[0], "$%&")
|
||||
|
||||
@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
|
||||
def test_whisper_audio_decoder(self):
|
||||
processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
pre_m, _ = gen_processing_models(processor,
|
||||
pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
|
||||
|
||||
fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0})
|
||||
test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
|
||||
raw_audio = np.fromfile(test_flac_file, dtype=np.uint8)
|
||||
log_mel = fn_pre(np.expand_dims(raw_audio, axis=0))
|
||||
|
||||
self.assertEqual(log_mel.shape, (1, 80, 3000))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче