add sentencepiece pre-trained model test (#110)
This commit is contained in:
Родитель
cb81344392
Коммит
f4e1be286a
|
@ -5,7 +5,7 @@
|
|||
|
||||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
from ._ocos import default_opset_domain, get_library_path # noqa
|
||||
from ._ocos import default_opset_domain
|
||||
|
||||
|
||||
class CustomOp:
|
||||
|
@ -88,9 +88,25 @@ class StringToVector(CustomOp):
|
|||
return attr_data
|
||||
|
||||
|
||||
# TODO: list all custom operators schema here:
|
||||
# ...
|
||||
# ...
|
||||
class SentencepieceTokenizer(CustomOp):
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [
|
||||
cls.io_def('inputs', onnx_proto.TensorProto.STRING, [None]),
|
||||
cls.io_def('nbest_size', onnx_proto.TensorProto.INT64, [None]),
|
||||
cls.io_def('alpha', onnx_proto.TensorProto.FLOAT, [None]),
|
||||
cls.io_def('add_bos', onnx_proto.TensorProto.BOOL, [None]),
|
||||
cls.io_def('add_eos', onnx_proto.TensorProto.BOOL, [None]),
|
||||
cls.io_def('reverse', onnx_proto.TensorProto.BOOL, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [
|
||||
cls.io_def('tokens', onnx_proto.TensorProto.INT32, [None]),
|
||||
cls.io_def('indices', onnx_proto.TensorProto.INT64, [None])
|
||||
|
||||
]
|
||||
|
||||
|
||||
class SingleOpGraph:
|
||||
|
@ -103,6 +119,9 @@ class SingleOpGraph:
|
|||
|
||||
@classmethod
|
||||
def build_my_graph(cls, op_class, *args, **kwargs):
|
||||
if isinstance(op_class, str):
|
||||
op_class = cls.get_op_class(op_class)
|
||||
|
||||
op_type = op_class.op_type()
|
||||
inputs = op_class.get_inputs()
|
||||
outputs = op_class.get_outputs()
|
||||
|
|
|
@ -34,6 +34,7 @@ class EagerOp:
|
|||
def __init__(self):
|
||||
self._onnx_model = None
|
||||
self.ort_session = None
|
||||
self.default_inputs = {}
|
||||
|
||||
def create_from_customop(self, op_type, *args, **kwargs):
|
||||
graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
|
||||
|
@ -43,6 +44,14 @@ class EagerOp:
|
|||
self._bind(model)
|
||||
return self
|
||||
|
||||
def add_default_input(self, **kwargs):
|
||||
inputs = {
|
||||
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else \
|
||||
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
|
||||
}
|
||||
|
||||
self.default_inputs.update(inputs)
|
||||
|
||||
@property
|
||||
def onnx_model(self):
|
||||
assert self._oxml is not None, "No onnx model attached yet."
|
||||
|
@ -81,6 +90,10 @@ class EagerOp:
|
|||
idx = 0
|
||||
feed = {}
|
||||
for i_ in self.inputs:
|
||||
if i_.name in self.default_inputs:
|
||||
feed[i_.name] = self.default_inputs[i_.name]
|
||||
continue
|
||||
|
||||
x = args[idx]
|
||||
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
||||
# an annoying bug is numpy by default is int32, while pytorch is int64.
|
||||
|
|
Двоичный файл не отображается.
|
@ -7,7 +7,7 @@ from numpy.testing import assert_almost_equal
|
|||
from onnx import helper, onnx_pb as onnx_proto
|
||||
import onnxruntime as _ort
|
||||
from onnxruntime_extensions import (
|
||||
onnx_op, PyCustomOpDef,
|
||||
onnx_op, PyCustomOpDef, PyOrtFunction,
|
||||
get_library_path as _get_library_path)
|
||||
import tensorflow as tf
|
||||
from tensorflow_text import SentencepieceTokenizer
|
||||
|
@ -429,6 +429,24 @@ class TestPythonOpSentencePiece(unittest.TestCase):
|
|||
assert_almost_equal(exp[i], py_txout[i])
|
||||
assert_almost_equal(exp[i], cc_txout[i])
|
||||
|
||||
def test_external_pretrained_model(self):
|
||||
fullname = os.path.join(
|
||||
os.path.dirname(__file__), 'data', 'en.wiki.bpe.vs100000.model')
|
||||
ofunc = PyOrtFunction.from_customop("SentencepieceTokenizer", model=open(fullname, 'rb').read())
|
||||
|
||||
alpha = 0
|
||||
nbest_size = 0
|
||||
flags = 0
|
||||
tokens, indices = ofunc(
|
||||
np.array(['best hotel in bay area.']),
|
||||
np.array(
|
||||
[nbest_size], dtype=np.int64),
|
||||
np.array([alpha], dtype=np.float32),
|
||||
np.array([flags & 1], dtype=np.bool_),
|
||||
np.array([flags & 2], dtype=np.bool_),
|
||||
np.array([flags & 4], dtype=np.bool_))
|
||||
self.assertEquals(tokens.tolist(), [1095, 4054, 26, 2022, 755, 99935])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче