An end-to-end BERT model with pre-/post- processing. (#224)

* bert demo

* add some comments

* support multiple outputs in ONNX model

* code polishing

* encoding issue on Windows platform.
This commit is contained in:
Wenbing Li 2022-04-20 16:14:46 -07:00 коммит произвёл GitHub
Родитель fb378b72a0
Коммит bfbfa5a304
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 144 добавлений и 51 удалений

Просмотреть файл

@ -32,4 +32,4 @@ PyOp = PyCustomOpDef
def get_test_data_file(*sub_dirs):
case_file = inspect.currentframe().f_back.f_code.co_filename
test_dir = pathlib.Path(case_file).parent
return str(test_dir.joinpath(*sub_dirs))
return str(test_dir.joinpath(*sub_dirs).resolve())

Просмотреть файл

@ -80,7 +80,7 @@ class OrtPyFunction:
def _bind(self, oxml):
self.inputs = list(oxml.graph.input)
self.output = list(oxml.graph.output)
self.outputs = list(oxml.graph.output)
self._oxml = oxml
return self

Просмотреть файл

@ -10,4 +10,4 @@ from ._torchext import * # noqa
from ._unifier import export
from ._imagenet import * # noqa
from ._nlp import PreHuggingFaceGPT2
from ._nlp import PreHuggingFaceGPT2, PreHuggingFaceBert # noqa

Просмотреть файл

@ -78,12 +78,16 @@ class _ProcessingModule:
class ProcessingTracedModule(torch.nn.Module, _ProcessingModule):
pass
def __init__(self, func_obj=None):
super().__init__()
self.func_obj = func_obj
def forward(self, *args):
assert self.func_obj is not None, "No forward method found."
return self.func_obj(*args)
class ProcessingScriptModule(torch.nn.Module, _ProcessingModule):
def __init__(self):
super(ProcessingScriptModule, self).__init__()
@torch.jit.unused
def export(self, *args, **kwargs):

Просмотреть файл

@ -12,6 +12,25 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
op_version=1, name=op_name, op_domain=default_opset_domain(), **kwargs)
@schema(inputs=((_dt.STRING, []),),
outputs=((_dt.INT64, []), (_dt.INT64, []), (_dt.INT64, [])))
def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
if 'hf_tok' in kwargs:
# TODO: need bert-tokenizer support JSON format
hf_bert_tokenizer = kwargs['hf_tok']
attrs = {'vocab_file': json.dumps(hf_bert_tokenizer.vocab, separators=(',', ':'))}
elif 'vocab_file' in kwargs:
attrs = dict(vocab_file=kwargs['vocab_file'])
else:
raise RuntimeError("Need hf_tok/vocab_file parameter to build the tokenizer")
if 'strip_accents' in kwargs:
strip_accents = kwargs['strip_accents']
attrs['strip_accents'] = strip_accents
return make_custom_op(ctx, 'BertTokenizer', input_names,
output_names, container, operator_name=operator_name, **attrs)
@schema(inputs=((_dt.STRING, []),),
outputs=((_dt.INT64, []), (_dt.INT64, [])))
def gpt2_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
@ -44,6 +63,30 @@ def _get_bound_object(func):
return func.__self__
class PreHuggingFaceBert(ProcessingTracedModule):
def __init__(self, hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1):
super(PreHuggingFaceBert, self).__init__()
if hf_tok is None:
_vocab = None
with open(vocab_file, "r", encoding='utf-8') as vf:
lines = vf.readlines()
_vocab = '\n'.join(lines)
if _vocab is None:
raise RuntimeError("Cannot load vocabulary file {}!".format(vocab_file))
self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize,
vocab_file=_vocab,
do_lower_case=do_lower_case,
strip_accents=strip_accents)
else:
self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize, hf_tok=self.hf_tok)
def forward(self, text):
return self.onnx_bert_tokenize(text)
def export(self, *args, **kwargs):
return _get_bound_object(self.onnx_bert_tokenize).build_model(kwargs.get('opset_version', 0), *args)
class PreHuggingFaceGPT2(ProcessingTracedModule):
def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
super(PreHuggingFaceGPT2, self).__init__()

Просмотреть файл

@ -8,7 +8,7 @@ from distutils.version import LooseVersion
from torch.onnx import register_custom_op_symbolic
from ._utils import ONNXModelUtils
from ._base import CustomFunction, ProcessingTracedModule
from ._base import CustomFunction, ProcessingTracedModule, is_processing_module
from ._onnx_ops import ox as _ox, schema as _schema
from ._onnx_ops import ONNXElementContainer, make_model_ex
from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
@ -120,42 +120,9 @@ onnx_where = create_op_function('Where', _ox.where)
onnx_greater = create_op_function('Greater', _ox.greater)
class PythonOpFunction:
"""
PythonOpFunction wraps a generic Python function which skips forward operation on torch.onnx.exporter
converting in the script mode, since the exporter cannot support the APIs from external package, like Numpy.
BTW, Autograd.Function cannot be used torch.jit.script.
"""
id_func_map = {}
current_func_id = 0
@staticmethod
def _get_next_id():
PythonOpFunction.current_func_id += 1
return PythonOpFunction.current_func_id
@staticmethod
@torch.jit.ignore
def _pass_through_call(*args, **kwargs):
func_id = args[0]
func = PythonOpFunction.id_func_map[func_id]
return torch.from_numpy(func.forward(*args[1:], **kwargs))
@classmethod
def apply(cls, *args, **kwargs):
return PythonOpFunction._pass_through_call(cls.get_id(), *args, **kwargs)
@classmethod
def get_id(cls):
if not hasattr(cls, 'func_id'):
_id = PythonOpFunction._get_next_id()
setattr(cls, 'func_id', _id)
PythonOpFunction.id_func_map[_id] = cls
return cls.func_id
class _OnnxModelFunction:
id_object_map = {} # cannot use the string directly since jit.script doesn't support the data type
id_function_map = {}
str_model_function_id = '_model_function_id'
str_model_id = '_model_id'
str_model_attached = '_model_attached'
@ -163,12 +130,16 @@ class _OnnxModelFunction:
@torch.jit.ignore
def _invoke_onnx_model(model_id: int, *args, **kwargs):
func = _OnnxModelFunction.id_function_map.get(model_id, None)
if not func:
model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
if model_or_path is None:
raise ValueError("cannot find id={} registered!".format(model_id))
func = OrtPyFunction.from_model(model_or_path)
return torch.from_numpy(
func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs))
_OnnxModelFunction.id_function_map[model_id] = func
results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
return tuple(
[torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)
@torch.jit.ignore
@ -193,7 +164,20 @@ class _OnnxTracedFunction(CustomFunction):
@classmethod
def symbolic(cls, g, *args):
return g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
ret = g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
model_id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0]) # noqa
if not model_id:
return ret
func = _OnnxModelFunction.id_function_map.get(model_id.item(), None)
if not func or len(func.outputs) <= 1:
return ret
outputs = [ret]
for _ in range(len(func.outputs) - 1):
outputs.append(ret.node().addOutput())
return tuple(outputs)
def create_model_function(model_or_path):
@ -242,8 +226,11 @@ class SequentialProcessingModule(ProcessingTracedModule):
for mdl_ in models:
if isinstance(mdl_, onnx.ModelProto):
self.model_list.append(_OnnxModelModule(mdl_))
else:
elif is_processing_module(mdl_):
self.model_list.append(mdl_)
else:
assert callable(mdl_), "the model type is not recognizable."
self.model_list.append(ProcessingTracedModule(mdl_))
def forward(self, *args):
outputs = args
@ -251,7 +238,7 @@ class SequentialProcessingModule(ProcessingTracedModule):
for idx_, mdl_ in enumerate(self.model_list):
if not isinstance(outputs, tuple):
outputs = (outputs,)
outputs = self.model_list[idx_](*outputs)
outputs = mdl_(*outputs)
return outputs

Просмотреть файл

@ -133,6 +133,8 @@ class ONNXModelUtils:
nest_model,
io_mapping)
for idx_, out_ in enumerate(nest_model.graph.output):
if idx_ >= len(_node.output):
continue
_renamed_out = "{}_{}".format(prefix, out_.name)
_nd = onnx.helper.make_node('Identity',
[_renamed_out],

Просмотреть файл

@ -104,7 +104,10 @@ class BuildCMakeExt(_build_ext):
self.spawn(['cmake', '--build', str(build_temp)] + build_args)
if sys.platform == "win32":
self.copy_file(build_temp / config / 'ortcustomops.dll',
config_dir = '.'
if not (build_temp / 'build.ninja').exists():
config_dir = config
self.copy_file(build_temp / config_dir / 'ortcustomops.dll',
self.get_ext_filename(extension.name))

54
tutorials/bert_e2e.py Normal file
Просмотреть файл

@ -0,0 +1,54 @@
import onnx
import torch
import onnxruntime_extensions
from pathlib import Path
from onnxruntime_extensions import pnp, OrtPyFunction
from transformers import AutoTokenizer
from transformers.onnx import export, FeaturesManager
# get an onnx model by converting HuggingFace pretrained model
model_name = "bert-base-cased"
model_path = Path("onnx-model/bert-base-cased.onnx")
if not model_path.exists():
if not model_path.parent.exists():
model_path.parent.mkdir(parents=True, exist_ok=True)
model = FeaturesManager.get_model_from_feature("default", model_name)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = model_onnx_config(model.config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
export(tokenizer,
model=model,
config=onnx_config,
opset=12,
output=model_path)
# a silly post-processing example function, demo-purpose only
def post_processing_forward(*pred):
return torch.softmax(pred[1], dim=1)
# mapping the BertTokenizer outputs into the onnx model inputs
def mapping_token_output(_1, _2, _3):
return _1.unsqueeze(0), _3.unsqueeze(0), _2.unsqueeze(0)
test_sentence = ["this is a test sentence."]
ort_tok = pnp.PreHuggingFaceBert(
vocab_file=onnxruntime_extensions.get_test_data_file(
'../test', 'data', 'bert_basic_cased_vocab.txt'))
onnx_model = onnx.load_model(str(model_path))
# create the final onnx model which includes pre- and post- processing.
augmented_model = pnp.export(pnp.SequentialProcessingModule(
ort_tok, mapping_token_output,
onnx_model, post_processing_forward),
test_sentence,
opset_version=12,
output_path='bert_tok_all.onnx')
# test the augmented onnx model with raw string input.
model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
result = model_func(test_sentence)
print(result)