An end-to-end BERT model with pre-/post- processing. (#224)

* bert demo
* add some comments
* support multiple outputs in ONNX model
* code polishing
* fix an encoding issue on the Windows platform

This commit is contained in:
Parent: fb378b72a0
Commit: bfbfa5a304
@@ -32,4 +32,4 @@ PyOp = PyCustomOpDef
 def get_test_data_file(*sub_dirs):
     case_file = inspect.currentframe().f_back.f_code.co_filename
     test_dir = pathlib.Path(case_file).parent
-    return str(test_dir.joinpath(*sub_dirs))
+    return str(test_dir.joinpath(*sub_dirs).resolve())
@@ -80,7 +80,7 @@ class OrtPyFunction:

     def _bind(self, oxml):
         self.inputs = list(oxml.graph.input)
-        self.output = list(oxml.graph.output)
+        self.outputs = list(oxml.graph.output)
         self._oxml = oxml
         return self
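Note: the plural rename is load-bearing: the multi-output export path added further down reads len(func.outputs) from an OrtPyFunction to decide how many outputs a traced node needs. A minimal sketch of that consumer side (the helper name is illustrative):

def model_output_count(func) -> int:
    # func is an OrtPyFunction bound to an ONNX model; the number of graph
    # outputs drives how many outputs the traced custom-op node must expose
    return len(func.outputs)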
@@ -10,4 +10,4 @@ from ._torchext import *  # noqa
 from ._unifier import export

 from ._imagenet import *  # noqa
-from ._nlp import PreHuggingFaceGPT2
+from ._nlp import PreHuggingFaceGPT2, PreHuggingFaceBert  # noqa
@@ -78,12 +78,16 @@ class _ProcessingModule:


 class ProcessingTracedModule(torch.nn.Module, _ProcessingModule):
-    pass
+    def __init__(self, func_obj=None):
+        super().__init__()
+        self.func_obj = func_obj
+
+    def forward(self, *args):
+        assert self.func_obj is not None, "No forward method found."
+        return self.func_obj(*args)


 class ProcessingScriptModule(torch.nn.Module, _ProcessingModule):
     def __init__(self):
         super(ProcessingScriptModule, self).__init__()

     @torch.jit.unused
     def export(self, *args, **kwargs):
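Note: ProcessingTracedModule changes from an empty marker class into a wrapper that can carry any callable as its forward implementation. A minimal usage sketch, assuming the class is reachable through the pnp namespace like its siblings in the demo below:

import torch
from onnxruntime_extensions import pnp

# wrap a bare callable as a pipeline stage; forward() delegates to it
stage = pnp.ProcessingTracedModule(lambda t: t.unsqueeze(0))
print(stage(torch.tensor([1, 2, 3])).shape)  # torch.Size([1, 3])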
@@ -12,6 +12,25 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_name=None, **kwargs):
                           op_version=1, name=op_name, op_domain=default_opset_domain(), **kwargs)


+@schema(inputs=((_dt.STRING, []),),
+        outputs=((_dt.INT64, []), (_dt.INT64, []), (_dt.INT64, [])))
+def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
+    if 'hf_tok' in kwargs:
+        # TODO: need the bert-tokenizer to support the JSON vocabulary format
+        hf_bert_tokenizer = kwargs['hf_tok']
+        attrs = {'vocab_file': json.dumps(hf_bert_tokenizer.vocab, separators=(',', ':'))}
+    elif 'vocab_file' in kwargs:
+        attrs = dict(vocab_file=kwargs['vocab_file'])
+    else:
+        raise RuntimeError("Need hf_tok/vocab_file parameter to build the tokenizer")
+    if 'strip_accents' in kwargs:
+        strip_accents = kwargs['strip_accents']
+        attrs['strip_accents'] = strip_accents
+
+    return make_custom_op(ctx, 'BertTokenizer', input_names,
+                          output_names, container, operator_name=operator_name, **attrs)
+
+
 @schema(inputs=((_dt.STRING, []),),
         outputs=((_dt.INT64, []), (_dt.INT64, [])))
 def gpt2_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
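Note: the @schema decorator pins the converter contract for the new BertTokenizer op: one string tensor in, three INT64 tensors out (token ids, token type ids, attention mask, judging by how the demo script below remaps them). For illustration, a hedged sketch of the kind of node this converter emits, built directly with onnx.helper; every name here is illustrative and the vocab attribute is a placeholder:

import onnx
from onnx import helper, TensorProto

node = helper.make_node(
    'BertTokenizer', ['text'], ['input_ids', 'token_type_ids', 'attention_mask'],
    domain='ai.onnx.contrib',
    vocab_file='<vocabulary contents>', do_lower_case=0, strip_accents=1)
graph = helper.make_graph(
    [node], 'bert_tok',
    [helper.make_tensor_value_info('text', TensorProto.STRING, [None])],
    [helper.make_tensor_value_info(n, TensorProto.INT64, [None])
     for n in ('input_ids', 'token_type_ids', 'attention_mask')])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid('', 12), helper.make_opsetid('ai.onnx.contrib', 1)])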
@@ -44,6 +63,30 @@ def _get_bound_object(func):
     return func.__self__


+class PreHuggingFaceBert(ProcessingTracedModule):
+    def __init__(self, hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1):
+        super(PreHuggingFaceBert, self).__init__()
+        if hf_tok is None:
+            _vocab = None
+            with open(vocab_file, "r", encoding='utf-8') as vf:
+                lines = vf.readlines()
+                _vocab = '\n'.join(lines)
+            if _vocab is None:
+                raise RuntimeError("Cannot load vocabulary file {}!".format(vocab_file))
+            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize,
+                                                         vocab_file=_vocab,
+                                                         do_lower_case=do_lower_case,
+                                                         strip_accents=strip_accents)
+        else:
+            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize, hf_tok=hf_tok)
+
+    def forward(self, text):
+        return self.onnx_bert_tokenize(text)
+
+    def export(self, *args, **kwargs):
+        return _get_bound_object(self.onnx_bert_tokenize).build_model(kwargs.get('opset_version', 0), *args)
+
+
 class PreHuggingFaceGPT2(ProcessingTracedModule):
     def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
         super(PreHuggingFaceGPT2, self).__init__()
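Note: a hedged usage sketch for the new class, mirroring how the demo script below drives it (the vocab path is a placeholder):

from onnxruntime_extensions import pnp

tok = pnp.PreHuggingFaceBert(vocab_file='bert_basic_cased_vocab.txt')
# eager run through the BertTokenizer custom op
ids, type_ids, mask = tok(["this is a test sentence."])
# or build a standalone ONNX model for just the tokenizer
tok_model = tok.export(["this is a test sentence."], opset_version=12)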
@@ -8,7 +8,7 @@ from distutils.version import LooseVersion
 from torch.onnx import register_custom_op_symbolic

 from ._utils import ONNXModelUtils
-from ._base import CustomFunction, ProcessingTracedModule
+from ._base import CustomFunction, ProcessingTracedModule, is_processing_module
 from ._onnx_ops import ox as _ox, schema as _schema
 from ._onnx_ops import ONNXElementContainer, make_model_ex
 from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
@@ -120,42 +120,9 @@ onnx_where = create_op_function('Where', _ox.where)
 onnx_greater = create_op_function('Greater', _ox.greater)


-class PythonOpFunction:
-    """
-    PythonOpFunction wraps a generic Python function which skips forward operation on torch.onnx.exporter
-    converting in the script mode, since the exporter cannot support the APIs from external package, like Numpy.
-    BTW, Autograd.Function cannot be used torch.jit.script.
-    """
-    id_func_map = {}
-    current_func_id = 0
-
-    @staticmethod
-    def _get_next_id():
-        PythonOpFunction.current_func_id += 1
-        return PythonOpFunction.current_func_id
-
-    @staticmethod
-    @torch.jit.ignore
-    def _pass_through_call(*args, **kwargs):
-        func_id = args[0]
-        func = PythonOpFunction.id_func_map[func_id]
-        return torch.from_numpy(func.forward(*args[1:], **kwargs))
-
-    @classmethod
-    def apply(cls, *args, **kwargs):
-        return PythonOpFunction._pass_through_call(cls.get_id(), *args, **kwargs)
-
-    @classmethod
-    def get_id(cls):
-        if not hasattr(cls, 'func_id'):
-            _id = PythonOpFunction._get_next_id()
-            setattr(cls, 'func_id', _id)
-            PythonOpFunction.id_func_map[_id] = cls
-        return cls.func_id
-
-
 class _OnnxModelFunction:
     id_object_map = {}  # cannot use the string directly since jit.script doesn't support the data type
+    id_function_map = {}
     str_model_function_id = '_model_function_id'
     str_model_id = '_model_id'
     str_model_attached = '_model_attached'
@@ -163,12 +130,16 @@ class _OnnxModelFunction:

 @torch.jit.ignore
 def _invoke_onnx_model(model_id: int, *args, **kwargs):
-    model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
-    if model_or_path is None:
-        raise ValueError("cannot find id={} registered!".format(model_id))
-    func = OrtPyFunction.from_model(model_or_path)
-    return torch.from_numpy(
-        func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs))
+    func = _OnnxModelFunction.id_function_map.get(model_id, None)
+    if not func:
+        model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
+        if model_or_path is None:
+            raise ValueError("cannot find id={} registered!".format(model_id))
+        func = OrtPyFunction.from_model(model_or_path)
+        _OnnxModelFunction.id_function_map[model_id] = func
+    results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
+    return tuple(
+        [torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)


 @torch.jit.ignore
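Note: two behaviors change here: the OrtPyFunction is now built once per model id and memoized in id_function_map, and a multi-output model now comes back as a tuple of tensors rather than through a single torch.from_numpy call. The boundary conversions, restated as a standalone sketch (helper names are illustrative):

import torch

def to_onnx_args(args):
    # torch tensors cross into ORT as numpy arrays; other values pass through
    return [a.numpy() if isinstance(a, torch.Tensor) else a for a in args]

def to_torch_results(results):
    # a multi-output model yields a tuple; keep that shape on the way back
    if isinstance(results, tuple):
        return tuple(torch.from_numpy(r) for r in results)
    return torch.from_numpy(results)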
@@ -193,7 +164,20 @@ class _OnnxTracedFunction(CustomFunction):

     @classmethod
     def symbolic(cls, g, *args):
-        return g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+        ret = g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+        model_id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0])  # noqa
+        if not model_id:
+            return ret
+
+        func = _OnnxModelFunction.id_function_map.get(model_id.item(), None)
+        if not func or len(func.outputs) <= 1:
+            return ret
+
+        outputs = [ret]
+        for _ in range(len(func.outputs) - 1):
+            outputs.append(ret.node().addOutput())
+
+        return tuple(outputs)


 def create_model_function(model_or_path):
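Note: the tracer gives a custom op a single output by default, so symbolic() now appends outputs to the _ModelFunctionCall node until it matches the registered model. When the output count is fixed and known up front, torch.onnx has a simpler declarative route; a hedged sketch for comparison:

def symbolic_fixed_outputs(g, *args):
    # g.op accepts an outputs= count and returns that many values; the change
    # above cannot use it because the count is only known from the registered
    # OrtPyFunction at export time
    return g.op('ai.onnx.contrib::_ModelFunctionCall', *args, outputs=2)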
@@ -242,8 +226,11 @@ class SequentialProcessingModule(ProcessingTracedModule):
         for mdl_ in models:
             if isinstance(mdl_, onnx.ModelProto):
                 self.model_list.append(_OnnxModelModule(mdl_))
-            else:
+            elif is_processing_module(mdl_):
                 self.model_list.append(mdl_)
+            else:
+                assert callable(mdl_), "the model type is not recognizable."
+                self.model_list.append(ProcessingTracedModule(mdl_))

     def forward(self, *args):
         outputs = args
@@ -251,7 +238,7 @@ class SequentialProcessingModule(ProcessingTracedModule):
         for idx_, mdl_ in enumerate(self.model_list):
             if not isinstance(outputs, tuple):
                 outputs = (outputs,)
-            outputs = self.model_list[idx_](*outputs)
+            outputs = mdl_(*outputs)

         return outputs
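Note: with these two hunks, SequentialProcessingModule accepts three kinds of stages: an onnx.ModelProto, any processing module, and now any plain callable, which is wrapped into a ProcessingTracedModule. A hedged sketch of the callable path:

import torch
from onnxruntime_extensions import pnp

pipeline = pnp.SequentialProcessingModule(
    lambda x: x * 2,           # bare callable, wrapped automatically
    lambda x: x.sum(dim=-1))   # single outputs are re-wrapped into tuples between stages
print(pipeline(torch.ones(2, 3)))  # tensor([6., 6.])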
@@ -133,6 +133,8 @@ class ONNXModelUtils:
                                              nest_model,
                                              io_mapping)
         for idx_, out_ in enumerate(nest_model.graph.output):
+            if idx_ >= len(_node.output):
+                continue
             _renamed_out = "{}_{}".format(prefix, out_.name)
             _nd = onnx.helper.make_node('Identity',
                                         [_renamed_out],
setup.py
@@ -104,7 +104,10 @@ class BuildCMakeExt(_build_ext):
         self.spawn(['cmake', '--build', str(build_temp)] + build_args)

         if sys.platform == "win32":
-            self.copy_file(build_temp / config / 'ortcustomops.dll',
+            config_dir = '.'
+            if not (build_temp / 'build.ninja').exists():
+                config_dir = config
+            self.copy_file(build_temp / config_dir / 'ortcustomops.dll',
                            self.get_ext_filename(extension.name))
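Note: the copy step now accounts for CMake generator layouts: Ninja is a single-config generator that writes ortcustomops.dll directly into the build tree, while Visual Studio generators add a per-configuration subdirectory. The probe in isolation, as a hedged sketch:

from pathlib import Path

def built_dll_path(build_temp: Path, config: str) -> Path:
    # build.ninja marks a single-config (Ninja) build; otherwise assume a
    # multi-config generator with a Release/Debug subdirectory
    config_dir = '.' if (build_temp / 'build.ninja').exists() else config
    return build_temp / config_dir / 'ortcustomops.dll'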
@@ -0,0 +1,54 @@
+import onnx
+import torch
+import onnxruntime_extensions
+
+from pathlib import Path
+from onnxruntime_extensions import pnp, OrtPyFunction
+from transformers import AutoTokenizer
+from transformers.onnx import export, FeaturesManager
+
+# get an onnx model by converting a HuggingFace pretrained model
+model_name = "bert-base-cased"
+model_path = Path("onnx-model/bert-base-cased.onnx")
+if not model_path.exists():
+    if not model_path.parent.exists():
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+    model = FeaturesManager.get_model_from_feature("default", model_name)
+    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
+    onnx_config = model_onnx_config(model.config)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    export(tokenizer,
+           model=model,
+           config=onnx_config,
+           opset=12,
+           output=model_path)
+
+
+# a silly post-processing example function, demo-purpose only
+def post_processing_forward(*pred):
+    return torch.softmax(pred[1], dim=1)
+
+
+# map the BertTokenizer outputs onto the onnx model inputs
+def mapping_token_output(_1, _2, _3):
+    return _1.unsqueeze(0), _3.unsqueeze(0), _2.unsqueeze(0)
+
+
+test_sentence = ["this is a test sentence."]
+ort_tok = pnp.PreHuggingFaceBert(
+    vocab_file=onnxruntime_extensions.get_test_data_file(
+        '../test', 'data', 'bert_basic_cased_vocab.txt'))
+onnx_model = onnx.load_model(str(model_path))
+
+# create the final onnx model which includes pre- and post- processing.
+augmented_model = pnp.export(pnp.SequentialProcessingModule(
+                                 ort_tok, mapping_token_output,
+                                 onnx_model, post_processing_forward),
+                             test_sentence,
+                             opset_version=12,
+                             output_path='bert_tok_all.onnx')
+
+# test the augmented onnx model with raw string input.
+model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
+result = model_func(test_sentence)
+print(result)
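Note: the augmented model is a plain ONNX file, so it can also be served by a vanilla onnxruntime session once the extensions' custom-op library is registered; a hedged sketch reusing the bert_tok_all.onnx produced above:

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession('bert_tok_all.onnx', so)
text_input = sess.get_inputs()[0].name
print(sess.run(None, {text_input: np.array(["this is a test sentence."])}))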