From bfbfa5a3044ec8d1312f3782c78ea3b9246bf667 Mon Sep 17 00:00:00 2001
From: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Date: Wed, 20 Apr 2022 16:14:46 -0700
Subject: [PATCH] An end-to-end BERT model with pre-/post- processing. (#224)

* bert demo
* add some comments
* support multiple outputs in ONNX model
* code polishing
* encoding issue on Windows platform.
---
 onnxruntime_extensions/__init__.py      |  2 +-
 onnxruntime_extensions/_ortapi2.py      |  2 +-
 onnxruntime_extensions/pnp/__init__.py  |  2 +-
 onnxruntime_extensions/pnp/_base.py     | 10 +++-
 onnxruntime_extensions/pnp/_nlp.py      | 43 ++++++++++++++
 onnxruntime_extensions/pnp/_torchext.py | 75 ++++++++++---------------
 onnxruntime_extensions/pnp/_utils.py    |  2 +
 setup.py                                |  5 +-
 tutorials/bert_e2e.py                   | 54 ++++++++++++++++++
 9 files changed, 144 insertions(+), 51 deletions(-)
 create mode 100644 tutorials/bert_e2e.py

diff --git a/onnxruntime_extensions/__init__.py b/onnxruntime_extensions/__init__.py
index 90a36ecb..135f4a55 100644
--- a/onnxruntime_extensions/__init__.py
+++ b/onnxruntime_extensions/__init__.py
@@ -32,4 +32,4 @@ PyOp = PyCustomOpDef
 def get_test_data_file(*sub_dirs):
     case_file = inspect.currentframe().f_back.f_code.co_filename
     test_dir = pathlib.Path(case_file).parent
-    return str(test_dir.joinpath(*sub_dirs))
+    return str(test_dir.joinpath(*sub_dirs).resolve())
diff --git a/onnxruntime_extensions/_ortapi2.py b/onnxruntime_extensions/_ortapi2.py
index 07e5003e..a234341f 100644
--- a/onnxruntime_extensions/_ortapi2.py
+++ b/onnxruntime_extensions/_ortapi2.py
@@ -80,7 +80,7 @@ class OrtPyFunction:
 
     def _bind(self, oxml):
         self.inputs = list(oxml.graph.input)
-        self.output = list(oxml.graph.output)
+        self.outputs = list(oxml.graph.output)
         self._oxml = oxml
         return self
 
diff --git a/onnxruntime_extensions/pnp/__init__.py b/onnxruntime_extensions/pnp/__init__.py
index f45dc8a2..94076ba6 100644
--- a/onnxruntime_extensions/pnp/__init__.py
+++ b/onnxruntime_extensions/pnp/__init__.py
@@ -10,4 +10,4 @@
 from ._torchext import *  # noqa
 from ._unifier import export
 from ._imagenet import *  # noqa
-from ._nlp import PreHuggingFaceGPT2
+from ._nlp import PreHuggingFaceGPT2, PreHuggingFaceBert  # noqa
diff --git a/onnxruntime_extensions/pnp/_base.py b/onnxruntime_extensions/pnp/_base.py
index f14b9124..00f986c8 100644
--- a/onnxruntime_extensions/pnp/_base.py
+++ b/onnxruntime_extensions/pnp/_base.py
@@ -78,12 +78,16 @@ class _ProcessingModule:
 
 
 class ProcessingTracedModule(torch.nn.Module, _ProcessingModule):
-    pass
+    def __init__(self, func_obj=None):
+        super().__init__()
+        self.func_obj = func_obj
+
+    def forward(self, *args):
+        assert self.func_obj is not None, "No forward method found."
+        return self.func_obj(*args)
 
 
 class ProcessingScriptModule(torch.nn.Module, _ProcessingModule):
-    def __init__(self):
-        super(ProcessingScriptModule, self).__init__()
 
     @torch.jit.unused
     def export(self, *args, **kwargs):
diff --git a/onnxruntime_extensions/pnp/_nlp.py b/onnxruntime_extensions/pnp/_nlp.py
index 70055963..f8f89fa0 100644
--- a/onnxruntime_extensions/pnp/_nlp.py
+++ b/onnxruntime_extensions/pnp/_nlp.py
@@ -12,6 +12,25 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
                        op_version=1, name=op_name, op_domain=default_opset_domain(), **kwargs)
 
 
+@schema(inputs=((_dt.STRING, []),),
+        outputs=((_dt.INT64, []), (_dt.INT64, []), (_dt.INT64, [])))
+def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
+    if 'hf_tok' in kwargs:
+        # TODO: need bert-tokenizer support JSON format
+        hf_bert_tokenizer = kwargs['hf_tok']
+        attrs = {'vocab_file': json.dumps(hf_bert_tokenizer.vocab, separators=(',', ':'))}
+    elif 'vocab_file' in kwargs:
+        attrs = dict(vocab_file=kwargs['vocab_file'])
+    else:
+        raise RuntimeError("Need hf_tok/vocab_file parameter to build the tokenizer")
+    if 'strip_accents' in kwargs:
+        strip_accents = kwargs['strip_accents']
+        attrs['strip_accents'] = strip_accents
+
+    return make_custom_op(ctx, 'BertTokenizer', input_names,
+                          output_names, container, operator_name=operator_name, **attrs)
+
+
 @schema(inputs=((_dt.STRING, []),),
         outputs=((_dt.INT64, []), (_dt.INT64, [])))
 def gpt2_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
@@ -44,6 +63,30 @@ def _get_bound_object(func):
     return func.__self__
 
 
+class PreHuggingFaceBert(ProcessingTracedModule):
+    def __init__(self, hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1):
+        super(PreHuggingFaceBert, self).__init__()
+        if hf_tok is None:
+            _vocab = None
+            with open(vocab_file, "r", encoding='utf-8') as vf:
+                lines = vf.readlines()
+                _vocab = '\n'.join(lines)
+            if _vocab is None:
+                raise RuntimeError("Cannot load vocabulary file {}!".format(vocab_file))
+            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize,
+                                                         vocab_file=_vocab,
+                                                         do_lower_case=do_lower_case,
+                                                         strip_accents=strip_accents)
+        else:
+            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize, hf_tok=hf_tok)
+
+    def forward(self, text):
+        return self.onnx_bert_tokenize(text)
+
+    def export(self, *args, **kwargs):
+        return _get_bound_object(self.onnx_bert_tokenize).build_model(kwargs.get('opset_version', 0), *args)
+
+
 class PreHuggingFaceGPT2(ProcessingTracedModule):
     def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
         super(PreHuggingFaceGPT2, self).__init__()
diff --git a/onnxruntime_extensions/pnp/_torchext.py b/onnxruntime_extensions/pnp/_torchext.py
index bfc059e5..cad2608b 100644
--- a/onnxruntime_extensions/pnp/_torchext.py
+++ b/onnxruntime_extensions/pnp/_torchext.py
@@ -8,7 +8,7 @@ from distutils.version import LooseVersion
 from torch.onnx import register_custom_op_symbolic
 
 from ._utils import ONNXModelUtils
-from ._base import CustomFunction, ProcessingTracedModule
+from ._base import CustomFunction, ProcessingTracedModule, is_processing_module
 from ._onnx_ops import ox as _ox, schema as _schema
 from ._onnx_ops import ONNXElementContainer, make_model_ex
 from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
@@ -120,42 +120,9 @@ onnx_where = create_op_function('Where', _ox.where)
 onnx_greater = create_op_function('Greater', _ox.greater)
 
 
-class PythonOpFunction:
-    """
-    PythonOpFunction wraps a generic Python function which skips forward operation on torch.onnx.exporter
-    converting in the script mode, since the exporter cannot support the APIs from external package, like Numpy.
-    BTW, Autograd.Function cannot be used torch.jit.script.
-    """
-    id_func_map = {}
-    current_func_id = 0
-
-    @staticmethod
-    def _get_next_id():
-        PythonOpFunction.current_func_id += 1
-        return PythonOpFunction.current_func_id
-
-    @staticmethod
-    @torch.jit.ignore
-    def _pass_through_call(*args, **kwargs):
-        func_id = args[0]
-        func = PythonOpFunction.id_func_map[func_id]
-        return torch.from_numpy(func.forward(*args[1:], **kwargs))
-
-    @classmethod
-    def apply(cls, *args, **kwargs):
-        return PythonOpFunction._pass_through_call(cls.get_id(), *args, **kwargs)
-
-    @classmethod
-    def get_id(cls):
-        if not hasattr(cls, 'func_id'):
-            _id = PythonOpFunction._get_next_id()
-            setattr(cls, 'func_id', _id)
-            PythonOpFunction.id_func_map[_id] = cls
-        return cls.func_id
-
-
 class _OnnxModelFunction:
     id_object_map = {}  # cannot use the string directly since jit.script doesn't support the data type
+    id_function_map = {}
     str_model_function_id = '_model_function_id'
     str_model_id = '_model_id'
     str_model_attached = '_model_attached'
@@ -163,12 +130,16 @@
 
 @torch.jit.ignore
 def _invoke_onnx_model(model_id: int, *args, **kwargs):
-    model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
-    if model_or_path is None:
-        raise ValueError("cannot find id={} registered!".format(model_id))
-    func = OrtPyFunction.from_model(model_or_path)
-    return torch.from_numpy(
-        func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs))
+    func = _OnnxModelFunction.id_function_map.get(model_id, None)
+    if not func:
+        model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
+        if model_or_path is None:
+            raise ValueError("cannot find id={} registered!".format(model_id))
+        func = OrtPyFunction.from_model(model_or_path)
+        _OnnxModelFunction.id_function_map[model_id] = func
+    results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
+    return tuple(
+        [torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)
 
 
 @torch.jit.ignore
@@ -193,7 +164,20 @@ class _OnnxTracedFunction(CustomFunction):
 
     @classmethod
     def symbolic(cls, g, *args):
-        return g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+        ret = g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+        model_id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0])  # noqa
+        if not model_id:
+            return ret
+
+        func = _OnnxModelFunction.id_function_map.get(model_id.item(), None)
+        if not func or len(func.outputs) <= 1:
+            return ret
+
+        outputs = [ret]
+        for _ in range(len(func.outputs) - 1):
+            outputs.append(ret.node().addOutput())
+
+        return tuple(outputs)
 
 
 def create_model_function(model_or_path):
@@ -242,8 +226,11 @@ class SequentialProcessingModule(ProcessingTracedModule):
         for mdl_ in models:
             if isinstance(mdl_, onnx.ModelProto):
                 self.model_list.append(_OnnxModelModule(mdl_))
-            else:
+            elif is_processing_module(mdl_):
                 self.model_list.append(mdl_)
+            else:
+                assert callable(mdl_), "the model type is not recognizable."
+                self.model_list.append(ProcessingTracedModule(mdl_))
 
     def forward(self, *args):
         outputs = args
@@ -251,7 +238,7 @@
         for idx_, mdl_ in enumerate(self.model_list):
             if not isinstance(outputs, tuple):
                 outputs = (outputs,)
-            outputs = self.model_list[idx_](*outputs)
+            outputs = mdl_(*outputs)
         return outputs
 
 
diff --git a/onnxruntime_extensions/pnp/_utils.py b/onnxruntime_extensions/pnp/_utils.py
index 89810617..31e2f140 100644
--- a/onnxruntime_extensions/pnp/_utils.py
+++ b/onnxruntime_extensions/pnp/_utils.py
@@ -133,6 +133,8 @@ class ONNXModelUtils:
                                               nest_model,
                                               io_mapping)
         for idx_, out_ in enumerate(nest_model.graph.output):
+            if idx_ >= len(_node.output):
+                continue
             _renamed_out = "{}_{}".format(prefix, out_.name)
             _nd = onnx.helper.make_node('Identity',
                                         [_renamed_out],
diff --git a/setup.py b/setup.py
index 6bf34dd1..a11802e2 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,10 @@ class BuildCMakeExt(_build_ext):
         self.spawn(['cmake', '--build', str(build_temp)] + build_args)
 
         if sys.platform == "win32":
-            self.copy_file(build_temp / config / 'ortcustomops.dll',
+            config_dir = '.'
+            if not (build_temp / 'build.ninja').exists():
+                config_dir = config
+            self.copy_file(build_temp / config_dir / 'ortcustomops.dll',
                           self.get_ext_filename(extension.name))
 
 
diff --git a/tutorials/bert_e2e.py b/tutorials/bert_e2e.py
new file mode 100644
index 00000000..0e66f34e
--- /dev/null
+++ b/tutorials/bert_e2e.py
@@ -0,0 +1,54 @@
+import onnx
+import torch
+import onnxruntime_extensions
+
+from pathlib import Path
+from onnxruntime_extensions import pnp, OrtPyFunction
+from transformers import AutoTokenizer
+from transformers.onnx import export, FeaturesManager
+
+# get an onnx model by converting HuggingFace pretrained model
+model_name = "bert-base-cased"
+model_path = Path("onnx-model/bert-base-cased.onnx")
+if not model_path.exists():
+    if not model_path.parent.exists():
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+    model = FeaturesManager.get_model_from_feature("default", model_name)
+    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
+    onnx_config = model_onnx_config(model.config)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    export(tokenizer,
+           model=model,
+           config=onnx_config,
+           opset=12,
+           output=model_path)
+
+
+# a silly post-processing example function, demo-purpose only
+def post_processing_forward(*pred):
+    return torch.softmax(pred[1], dim=1)
+
+
+# mapping the BertTokenizer outputs into the onnx model inputs
+def mapping_token_output(_1, _2, _3):
+    return _1.unsqueeze(0), _3.unsqueeze(0), _2.unsqueeze(0)
+
+
+test_sentence = ["this is a test sentence."]
+ort_tok = pnp.PreHuggingFaceBert(
+    vocab_file=onnxruntime_extensions.get_test_data_file(
+        '../test', 'data', 'bert_basic_cased_vocab.txt'))
+onnx_model = onnx.load_model(str(model_path))
+
+# create the final onnx model which includes pre- and post- processing.
+augmented_model = pnp.export(pnp.SequentialProcessingModule(
+                                 ort_tok, mapping_token_output,
+                                 onnx_model, post_processing_forward),
+                             test_sentence,
+                             opset_version=12,
+                             output_path='bert_tok_all.onnx')
+
+# test the augmented onnx model with raw string input.
+model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
+result = model_func(test_sentence)
+print(result)
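
Usage note: the tutorial above validates the merged 'bert_tok_all.onnx' through
OrtPyFunction. As a minimal sketch, the same model can also be run with an
ONNX Runtime InferenceSession directly, provided the extensions' custom-op
library is registered first (via the public get_library_path() helper); the
input name is queried from the session rather than assumed:

    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    # register the custom ops (BertTokenizer et al.) before loading the model
    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())

    sess = ort.InferenceSession('bert_tok_all.onnx', so)
    input_name = sess.get_inputs()[0].name  # query instead of hard-coding

    # the augmented model consumes raw strings and returns the softmax scores
    outputs = sess.run(None, {input_name: np.array(["this is a test sentence."])})
    print(outputs)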