support mobilebert_ppp (#354)
* support mobilebert_ppp
* renaming IOEntryValuePreserver
* generalize argmax step

Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
This commit is contained in:
Parent
ef3df607dd
Commit
b375cb57e6
|
@ -7,7 +7,7 @@ import onnx
|
|||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Union
|
||||
# NOTE: If you're working on this script, install onnxruntime_extensions using `pip install -e .` from the repo root
|
||||
# and run with `python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model`
|
||||
# Running directly will result in an error from a relative import.
|
||||
|
@ -162,21 +162,110 @@ def superresolution(model_file: Path, output_file: Path, output_format: str, onn
|
|||
onnx.save_model(new_model, str(output_file.resolve()))
|
||||
|
||||
|
||||
class NLPTaskType(enum.Enum):
|
||||
TokenClassification = enum.auto()
|
||||
QuestionAnswering = enum.auto()
|
||||
SequenceClassification = enum.auto()
|
||||
NextSentencePrediction = enum.auto()
|
||||
|
||||
|
||||
class TokenizerType(enum.Enum):
|
||||
BertTokenizer = enum.auto()
|
||||
SentencePieceTokenizer = enum.auto()
|
||||
|
||||
|
||||
def transformers_and_bert(
|
||||
input_model_file: Path,
|
||||
output_model_file: Path,
|
||||
vocab_file: Path,
|
||||
tokenizer_type: Union[TokenizerType, str],
|
||||
task_type: Union[NLPTaskType, str],
|
||||
onnx_opset: int = 16,
|
||||
add_debug_before_postprocessing=False,
|
||||
):
|
||||
"""construct the pipeline for a end2end model with pre and post processing. The final model can take text as inputs
|
||||
and output the result in text format for model like QA.
|
||||
|
||||
Args:
|
||||
input_model_file (Path): the model file to be updated.
|
||||
output_model_file (Path): where to save the final onnx model.
|
||||
vocab_file (Path): the vocab file for the tokenizer.
tokenizer_type (Union[TokenizerType, str]): the tokenizer to use, BertTokenizer or SentencePieceTokenizer.
task_type (Union[NLPTaskType, str]): the task type of the model.
|
||||
onnx_opset (int, optional): the opset version to use. Defaults to 16.
|
||||
add_debug_before_postprocessing (bool, optional): whether to add a debug step before post processing.
|
||||
Defaults to False.
|
||||
"""
|
||||
if isinstance(task_type, str):
|
||||
task_type = NLPTaskType[task_type]
|
||||
if isinstance(tokenizer_type, str):
|
||||
tokenizer_type = TokenizerType[tokenizer_type]
|
||||
|
||||
onnx_model = onnx.load(str(input_model_file.resolve(strict=True)))
|
||||
# hardcode batch size to 1
|
||||
inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]
|
||||
|
||||
pipeline = PrePostProcessor(inputs, onnx_opset)
|
||||
tokenizer_args = TokenizerParam(
|
||||
vocab_or_file=vocab_file,
|
||||
do_lower_case=True,
|
||||
tweaked_bos_id=0,
|
||||
is_sentence_pair=task_type in (NLPTaskType.QuestionAnswering, NLPTaskType.NextSentencePrediction),
|
||||
)
|
||||
|
||||
preprocessing = [
|
||||
SentencePieceTokenizer(tokenizer_args)
|
||||
if tokenizer_type == TokenizerType.SentencePieceTokenizer else BertTokenizer(tokenizer_args),
|
||||
# uncomment this line to debug
|
||||
# Debug(2),
|
||||
]
|
||||
|
||||
# For verifying results without postprocessing
|
||||
postprocessing = [Debug()] if add_debug_before_postprocessing else []
|
||||
if task_type == NLPTaskType.QuestionAnswering:
|
||||
postprocessing.append((BertTokenizerQADecoder(tokenizer_args), [
|
||||
# input_ids
|
||||
utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]))
|
||||
elif task_type == NLPTaskType.SequenceClassification:
|
||||
postprocessing.append(ArgMax())
|
||||
# the other tasks don't need postprocessing or we don't support it yet.
|
||||
|
||||
pipeline.add_pre_processing(preprocessing)
|
||||
pipeline.add_post_processing(postprocessing)
|
||||
|
||||
new_model = pipeline.run(onnx_model)
|
||||
onnx.save_model(new_model, str(output_model_file.resolve()))
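# A minimal usage sketch (the file paths are illustrative, assuming a MobileBERT QA model and its vocab file):
#
#   transformers_and_bert(
#       input_model_file=Path("mobilebert_qa.onnx"),
#       output_model_file=Path("mobilebert_qa.with_pre_post_processing.onnx"),
#       vocab_file=Path("vocab.txt"),
#       tokenizer_type="BertTokenizer",
#       task_type="QuestionAnswering",
#   )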
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
os.path.basename(__file__),
|
||||
description="""Add pre and post processing to a model.
|
||||
|
||||
Vision models:
  - super resolution with YCbCr input
  - imagenet trained mobilenet
NLP models:
  - MobileBert with different tasks
  - XLM-Roberta with classification task

For Vision models:
  To customize, the logic in the `mobilenet` and `superresolution` functions can be used as a guide.
  Create a pipeline and add the required pre/post processing 'Steps' in the order required. Configure
  individual steps as needed.

For NLP models:
  `transformers_and_bert` can be used for MobileBert QuestionAnswering/Classification tasks,
  or serve as a guide for how to add pre/post processing to a transformer model.
  Usually pre-processing includes adding a tokenizer. Post-processing includes conversion of output_ids to text.

  You might need to pass the tokenizer model file (bert vocab file or SentencePieceTokenizer model)
  and task_type to the function.

The updated model will be written in the same location as the original model,
with '.onnx' updated to '.with_pre_post_processing.onnx'
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -185,7 +274,11 @@ def main():
|
|||
"--model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["superresolution", "mobilenet"],
|
||||
choices=[
|
||||
"superresolution",
|
||||
"mobilenet",
|
||||
"transformers",
|
||||
],
|
||||
help="Model type.",
|
||||
)
|
||||
|
||||
|
@ -212,9 +305,35 @@ def main():
|
|||
help="Image output format for superresolution model to produce.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nlp_task_type",
|
||||
type=str,
|
||||
choices=["QuestionAnswering",
|
||||
"SequenceClassification",
|
||||
"NextSentencePrediction"],
|
||||
required=False,
|
||||
help="The downstream task for NLP model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab_file",
|
||||
type=Path,
|
||||
required=False,
|
||||
help="Tokenizer model file for BertTokenizer or SentencePieceTokenizer.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_type",
|
||||
type=str,
|
||||
choices=["BertTokenizer",
|
||||
"SentencePieceTokenizer"],
|
||||
required=False,
|
||||
help="Tokenizer model file for BertTokenizer or SentencePieceTokenizer.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--opset", type=int, required=False, default=16,
|
||||
help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing."
|
||||
help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing.",
|
||||
)
|
||||
|
||||
parser.add_argument("model", type=Path, help="Provide path to ONNX model to update.")
|
||||
|
@ -227,8 +346,13 @@ def main():
|
|||
if args.model_type == "mobilenet":
|
||||
source = ModelSource.PYTORCH if args.model_source == "pytorch" else ModelSource.TENSORFLOW
|
||||
mobilenet(model_path, new_model_path, source, args.opset)
|
||||
elif args.model_type == "superresolution":
|
||||
superresolution(model_path, new_model_path,
|
||||
args.output_format, args.opset)
|
||||
else:
|
||||
superresolution(model_path, new_model_path, args.output_format, args.opset)
|
||||
if args.vocab_file is None or args.nlp_task_type is None or args.tokenizer_type is None:
|
||||
parser.error("Please provide vocab file/nlp_task_type/tokenizer_type.")
|
||||
transformers_and_bert(model_path, new_model_path, args.vocab_file, args.tokenizer_type, args.nlp_task_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,10 +8,9 @@ Classes
|
|||
: Step that can be arbitrarily inserted in the pre or post processing pipeline.
|
||||
It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.
|
||||
|
||||
NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
|
||||
may or may not have '_debug' as a suffix.
|
||||
TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
|
||||
to rename so it's more consistent.
|
||||
The output will be duplicated into two outputs, one will be renamed with a suffix "_next",
|
||||
another will be renamed with a suffix "_debug". The "_next" outputs will feed into the next step,
|
||||
the "_debug" outputs will become graph outputs.
|
||||
|
||||
Initialize Debug step
|
||||
Args:
|
||||
|
@ -35,11 +34,15 @@ Classes
|
|||
### Descendants
|
||||
|
||||
* pre_post_processing.step.Debug
|
||||
* pre_post_processing.steps.general.ArgMax
|
||||
* pre_post_processing.steps.general.ReverseAxis
|
||||
* pre_post_processing.steps.general.Softmax
|
||||
* pre_post_processing.steps.general.Squeeze
|
||||
* pre_post_processing.steps.general.Transpose
|
||||
* pre_post_processing.steps.general.Unsqueeze
|
||||
* pre_post_processing.steps.nlp.BertTokenizer
|
||||
* pre_post_processing.steps.nlp.BertTokenizerQADecoder
|
||||
* pre_post_processing.steps.nlp.SentencePieceTokenizer
|
||||
* pre_post_processing.steps.vision.CenterCrop
|
||||
* pre_post_processing.steps.vision.ConvertBGRToImage
|
||||
* pre_post_processing.steps.vision.ConvertImageToBGR
|
||||
|
@ -57,9 +60,12 @@ Classes
|
|||
|
||||
### Methods
|
||||
|
||||
`apply(self, graph: onnx.onnx_ml_pb2.GraphProto, checker_context: onnx.onnx_cpp2py_export.checker.CheckerContext)`
|
||||
`apply(self, graph: onnx.onnx_ml_pb2.GraphProto, checker_context: onnx.onnx_cpp2py_export.checker.CheckerContext, graph_outputs_to_maintain: List[str])`
|
||||
: Create a graph for this step that can be appended to the provided graph.
|
||||
The PrePostProcessor will handle merging the two.
|
||||
graph_outputs_to_maintain: List of output names to maintain in the graph.
|
||||
Some outputs might be consumed during graph merging, so we have to explicitly maintain them for subsequent steps.
These outputs are generated by IOEntryValuePreserver.
|
||||
|
||||
`connect(self, entry: pre_post_processing.utils.IoMapEntry)`
|
||||
: Connect the value name from a previous step to an input of this step so they match.
|
||||
|
|
|
@ -4,6 +4,18 @@ Module pre_post_processing.steps.general
|
|||
Classes
|
||||
-------
|
||||
|
||||
`ArgMax(name: Optional[str] = None, axis: int = -1, keepdims: int = 0)`
|
||||
: Base class for a pre or post processing step.
|
||||
|
||||
Brief:
|
||||
Same as ArgMax op.
|
||||
Args:
|
||||
name: Optional name of step. Defaults to 'ArgMax'
|
||||
|
||||
### Ancestors (in MRO)
|
||||
|
||||
* pre_post_processing.step.Step
|
||||
|
||||
`ReverseAxis(axis: int = -1, dim_value: int = -1, name: Optional[str] = None)`
|
||||
: Reverses the data in an axis by splitting and concatenating in reverse order.
|
||||
e.g. convert RGB ordered data to BGR.
|
||||
|
|
|
@ -4,4 +4,5 @@ Module pre_post_processing.steps
|
|||
Sub-modules
|
||||
-----------
|
||||
* pre_post_processing.steps.general
|
||||
* pre_post_processing.steps.nlp
|
||||
* pre_post_processing.steps.vision
|
|
@ -0,0 +1,68 @@
|
|||
Module pre_post_processing.steps.nlp
|
||||
====================================
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
||||
`BertTokenizer(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, name: Optional[str] = None)`
|
||||
: Base class for a pre or post processing step.
|
||||
|
||||
Brief: This step is used to convert the input text into the input_ids, attention_mask, token_type_ids.
|
||||
It supports an input of a single string for classification models, or two strings for QA models.
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer,
|
||||
You can create a TokenizerParam like this:
|
||||
tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, # vocab is dict or file_path
|
||||
strip_accents = True or False (Optional),
|
||||
do_lower_case = True or False (Optional)
|
||||
)
|
||||
|
||||
name: Optional name of step. Defaults to 'BertTokenizer'
|
||||
|
||||
### Ancestors (in MRO)
|
||||
|
||||
* pre_post_processing.step.Step
|
||||
|
||||
`BertTokenizerQADecoder(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, name: Optional[str] = None)`
|
||||
: Base class for a pre or post processing step.
|
||||
|
||||
Brief:
|
||||
Decode the input_ids to text
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer.
|
||||
you can create a TokenizerParam object like:
|
||||
tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, #vocab is dict or file_path)
|
||||
name: Optional name of step. Defaults to 'BertTokenizerQADecoder'
|
||||
|
||||
### Ancestors (in MRO)
|
||||
|
||||
* pre_post_processing.step.Step
|
||||
|
||||
`SentencePieceTokenizer(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, nbest_size=0, alpha=1.0, reverse=False, add_bos=False, add_eos=False, name: Optional[str] = None)`
|
||||
: Base class for a pre or post processing step.
|
||||
|
||||
Brief:
|
||||
SentencePieceTokenizer actually has 6 inputs in its definition, but we only require the user to provide the text input,
and make the others ("nbest_size", "alpha", "add_bos", "add_eos", "reverse") optional.
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer
|
||||
you can create a TokenizerParam object like:
|
||||
tokenizer_param = TokenizerParam(vocab_size=tokenizer.vocab_size,
|
||||
tweaked_bos_id=tokenizer.tweaked_bos_id)
|
||||
|
||||
nbest_size: int, optional (default = 0)
|
||||
alpha: float, optional (default = 1.0)
|
||||
reverse: bool, optional (default = False)
|
||||
add_bos: bool, optional (default = False)
|
||||
add_eos: bool, optional (default = False)
|
||||
Please see more detail explanation in
|
||||
https://www.tensorflow.org/text/api_docs/python/text/SentencepieceTokenizer#args
|
||||
|
||||
name: Optional name of step. Defaults to 'SentencePieceTokenizer'
|
||||
|
||||
### Ancestors (in MRO)
|
||||
|
||||
* pre_post_processing.step.Step
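A construction sketch based on the XLM-Roberta test added in this change (the sentencepiece model path is illustrative):

    params = TokenizerParam(vocab_or_file="sentencepiece.bpe.model", tweaked_bos_id=0)
    tokenizer_step = SentencePieceTokenizer(params, add_bos=True, add_eos=True)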
|
||||
|
||||
`TokenizerParam(vocab_or_file: Union[pathlib.Path, dict], **kwargs)`
|
||||
:
|
|
@ -46,6 +46,38 @@ Functions
|
|||
Classes
|
||||
-------
|
||||
|
||||
`IOEntryValuePreserver(producer: Union[ForwardRef('Step'), str] = None, consumer: Union[ForwardRef('Step'), str] = None, producer_idx: int = 0, is_active: bool = False, output: str = None)`
|
||||
: used to allow an output value to have multiple consumers,
|
||||
which is only possible when IoMapEntry is used to create those additional connections.
|
||||
|
||||
Generally, a connection consumes an output and an input, then the output is removed from the graph.
|
||||
This class enables one-to-many connections by making the other consumers share the same output.
|
||||
|
||||
How this class works:
|
||||
1. When the IoMapEntry is created, this class is created at the same time.
2. It records the producer and consumer steps, and the output index of the producer step.
   When the producer step runs, this IOEntryValuePreserver is activated and starts preserving the output.
3. When graphs are merged, this class checks whether the output is still in the graph and, if not, adds it back.
4. When the consumer step runs, this class is deactivated and the output is removed from the preserved list
   (see the sketch below).
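A sketch of the kind of connection that needs this, taken from the QA post-processing in this change: the tokenizer's input_ids output feeds both the model and the decoder, so it must be preserved across graph merges.

    postprocessing = [
        (BertTokenizerQADecoder(tokenizer_args),
         [utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]),
    ]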
|
||||
|
||||
### Class variables
|
||||
|
||||
`consumer: Union[Step, str]`
|
||||
:
|
||||
|
||||
`is_active: bool`
|
||||
:
|
||||
|
||||
`output: str`
|
||||
:
|
||||
|
||||
`producer: Union[Step, str]`
|
||||
:
|
||||
|
||||
`producer_idx: int`
|
||||
:
|
||||
|
||||
`IoMapEntry(producer: Union[ForwardRef('Step'), str] = None, producer_idx: int = 0, consumer_idx: int = 0)`
|
||||
: Entry to map the output index from a producer step to the input index of a consumer step.
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import List, Tuple, Union
|
|||
|
||||
from .utils import (
|
||||
IoMapEntry,
|
||||
IOEntryValuePreserver,
|
||||
create_custom_op_checker_context,
|
||||
sanitize_output_names,
|
||||
TENSOR_TYPE_TO_ONNX_TYPE,
|
||||
|
@ -56,6 +57,12 @@ class PrePostProcessor:
|
|||
self._post_processing_joins = None # type: Union[None,List[Tuple[Union[Step, str], int, str]]]
|
||||
|
||||
self._inputs = inputs if inputs else []
|
||||
|
||||
# Preserve outputs from IoMapEntry so they aren't consumed by follow-up steps.
# With IOEntryValuePreserver an output value can now have more than one consumer.
# IOEntryValuePreserver will preserve the output value and keep it as a graph output
# until the consumer step is done.
|
||||
self.outputs_preserver = [] # type: List[IOEntryValuePreserver]
|
||||
|
||||
def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
|
||||
"""
|
||||
|
@ -140,12 +147,31 @@ class PrePostProcessor:
|
|||
n.name = prefix + str(idx)
|
||||
idx += 1
|
||||
|
||||
def preserved_apply(processor: Step, *args):
|
||||
# Activate the IOEntryValuePreservers so their outputs are preserved,
# and deactivate them once the current graph has consumed those outputs.
|
||||
|
||||
for preserver in self.outputs_preserver:
|
||||
if preserver.consumer == processor:
|
||||
preserver.is_active = False
|
||||
|
||||
# IOEntryValuePreserver: preserve outputs that have multiple consumers
# by explicitly adding them to the graph outputs.
|
||||
graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
|
||||
graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)
|
||||
|
||||
for preserver in self.outputs_preserver:
|
||||
if preserver.producer == processor:
|
||||
preserver.is_active = True
|
||||
preserver.output = processor.output_names[preserver.producer_idx]
|
||||
return graph_for_step
|
||||
|
||||
def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
|
||||
for connection in connections:
|
||||
assert connection.producer
|
||||
self._add_connection(processor, connection)
|
||||
|
||||
return processor.apply(graph, self._custom_op_checker_context)
|
||||
return preserved_apply(processor, graph, self._custom_op_checker_context)
|
||||
|
||||
# fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
|
||||
if self.post_processors:
|
||||
|
@ -173,11 +199,22 @@ class PrePostProcessor:
|
|||
self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]
|
||||
|
||||
# map the pre-processing outputs to graph inputs
|
||||
# we may need a neater way to get the possible outputs after merge_graphs
|
||||
step_graph_outputs = [o.name for o in pre_process_graph.output]
|
||||
io_map = [] # type: List[Tuple[str, str]]
|
||||
for step, step_idx, graph_input in self._pre_processing_joins:
|
||||
io_map.append((step.output_names[step_idx], graph_input))
|
||||
step_graph_outputs.remove((step.output_names[step_idx]))
|
||||
|
||||
graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map)
|
||||
# add outputs from previous IoMapEntry producers to maintain them as graph outputs
|
||||
# until consumed by the final Step that requires them.
|
||||
step_graph_outputs += [
|
||||
o.name for o in graph.output if o.name not in step_graph_outputs]
|
||||
external_outputs = [
|
||||
i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs]
|
||||
if external_outputs:
|
||||
step_graph_outputs.extend(external_outputs)
|
||||
graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)
|
||||
|
||||
# add post-processing
|
||||
if self.post_processors:
|
||||
|
@ -274,6 +311,7 @@ class PrePostProcessor:
|
|||
producer = self.__producer_from_step_or_str(entry.producer) # throws if not found
|
||||
|
||||
io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
|
||||
self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))
|
||||
|
||||
processors.append(step)
|
||||
processor_connections.append([entry for entry in io_map_entries if entry is not None])
|
||||
|
|
|
@ -48,10 +48,18 @@ class Step(object):
|
|||
|
||||
self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]
|
||||
|
||||
def apply(self, graph: onnx.GraphProto, checker_context: onnx.checker.C.CheckerContext):
|
||||
def apply(self, graph: onnx.GraphProto,
|
||||
checker_context: onnx.checker.C.CheckerContext,
|
||||
graph_outputs_to_maintain: List[str]):
|
||||
"""
|
||||
Create a graph for this step that can be appended to the provided graph.
|
||||
The PrePostProcessor will handle merging the two.
|
||||
|
||||
Args:
|
||||
graph_outputs_to_maintain: List of output names that need extra effort to keep as graph outputs.
Outputs with multiple consumers would otherwise be consumed by default, which would prevent
subsequent steps from connecting to them.
These outputs are generated by IOEntryValuePreserver.
|
||||
"""
|
||||
|
||||
onnx_opset = checker_context.opset_imports[""]
|
||||
|
@ -60,7 +68,7 @@ class Step(object):
|
|||
|
||||
# prefix the graph for this step to guarantee no clashes of value names with the existing graph
|
||||
onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
|
||||
result = self.__merge(graph, graph_for_step)
|
||||
result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)
|
||||
|
||||
# update self.output_names to the prefixed names so that when we connect later Steps the values match
|
||||
new_outputs = [self._prefix + o for o in self.output_names]
|
||||
|
@ -86,8 +94,10 @@ class Step(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto):
|
||||
def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
|
||||
graph_outputs_to_maintain: Optional[List[str]] = None):
|
||||
# We prefixed all the value names in `second`, so allow for that when connecting the two graphs
|
||||
first_output = [o.name for o in first.output]
|
||||
io_map = []
|
||||
for o in first.output:
|
||||
# apply the same prefix to the output from the previous step to match the prefixed graph from this step
|
||||
|
@ -95,24 +105,13 @@ class Step(object):
|
|||
for i in second.input:
|
||||
if i.name == prefixed_output:
|
||||
io_map.append((o.name, i.name))
|
||||
|
||||
outputs_to_preserve = None
|
||||
|
||||
# special handling of Debug class.
|
||||
if isinstance(self, Debug):
|
||||
# preserve outputs of the first graph so they're available downstream. otherwise they are consumed by
|
||||
# the Debug node and disappear during the ONNX graph_merge as it considers consumed values to be
|
||||
# internal - which is entirely reasonable when merging graphs.
|
||||
# the issue we have is that we don't know what future steps might want things to remain as outputs.
|
||||
# the current approach is to insert a Debug step which simply duplicates the values so that they are
|
||||
# guaranteed not be consumed (only one of the two copies will be used).
|
||||
# doesn't change the number of outputs from the previous step, so it can be transparently inserted in the
|
||||
# pre/post processing pipeline.
|
||||
# need to also list the second graph's outputs when manually specifying outputs.
|
||||
outputs_to_preserve = [o.name for o in first.output] + [o.name for o in second.output]
|
||||
|
||||
first_output.remove(o.name)
|
||||
|
||||
graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output]
|
||||
graph_outputs += [o for o in graph_outputs_to_maintain if o not in graph_outputs]
|
||||
|
||||
# merge with existing graph
|
||||
merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=outputs_to_preserve)
|
||||
merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)
|
||||
|
||||
return merged_graph
|
||||
|
||||
|
@ -154,16 +153,14 @@ class Step(object):
|
|||
return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
|
||||
|
||||
|
||||
# special case. we include the helper Debug step here as logic in the base class is conditional on it.
|
||||
class Debug(Step):
|
||||
"""
|
||||
Step that can be arbitrarily inserted in the pre or post processing pipeline.
|
||||
It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.
|
||||
|
||||
NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
|
||||
may or may not have '_debug' as a suffix.
|
||||
TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
|
||||
to rename so it's more consistent.
|
||||
The output will be duplicated into two outputs, one will be renamed with a suffix "_next",
|
||||
another will be renamed with a suffix "_debug". The "_next" outputs will feed into the next step,
|
||||
the "_debug" outputs will become graph outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
|
||||
|
@ -180,32 +177,43 @@ class Debug(Step):
|
|||
super().__init__(input_names, output_names, name)
|
||||
|
||||
def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
|
||||
input_str = ""
|
||||
output_str = ""
|
||||
output_debug_str = ""
|
||||
nodes_str = ""
|
||||
if self._num_inputs > len(graph.output):
|
||||
raise ValueError(
|
||||
f"Debug step requested {self._num_inputs} inputs, but graph only has {len(graph.output)} outputs.")
|
||||
|
||||
debug_offset = len(self.input_names)
|
||||
# update output names so we preserve info from the latest input names
|
||||
self.output_names = [f"{name}_debug" for name in self.input_names]
|
||||
self.output_names = [f"{name}_next" for name in self.input_names]
|
||||
self.output_names += [f"{name}_debug" for name in self.input_names]
|
||||
|
||||
input_str_list = []
|
||||
output_str_list = []
|
||||
nodes_str_list = []
|
||||
for i in range(0, self._num_inputs):
|
||||
input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i)
|
||||
if i > 0:
|
||||
input_str += ", "
|
||||
output_str += ", "
|
||||
output_debug_str += ", "
|
||||
nodes_str += "\n"
|
||||
input_type_str, input_shape_str = self._get_input_type_and_shape_strs(
|
||||
graph, i)
|
||||
|
||||
input_str += f"{input_type_str}[{input_shape_str}] {self.input_names[i]}"
|
||||
output_str += f"{input_type_str}[{input_shape_str}] {self.output_names[i]}"
|
||||
nodes_str += f"{self.output_names[i]} = Identity({self.input_names[i]})\n"
|
||||
input_str_list.append(
|
||||
f"{input_type_str}[{input_shape_str}] {self.input_names[i]}")
|
||||
|
||||
output_str_list.append(
|
||||
f"{input_type_str}[{input_shape_str}] {self.output_names[i]}")
|
||||
output_str_list.append(
|
||||
f"{input_type_str}[{input_shape_str}] {self.output_names[debug_offset+i]}")
|
||||
|
||||
nodes_str_list.append(
|
||||
f"{self.output_names[i]} = Identity({self.input_names[i]})\n")
|
||||
nodes_str_list.append(
|
||||
f"{self.output_names[debug_offset+i]} = Identity({self.input_names[i]})\n")
|
||||
|
||||
# f-string can't have back-slash
|
||||
node_str = '\n'.join(nodes_str_list)
|
||||
debug_graph = onnx.parser.parse_graph(
|
||||
f"""\
|
||||
debug ({input_str})
|
||||
=> ({output_str})
|
||||
debug ({','.join(input_str_list)})
|
||||
=> ({','.join(output_str_list)})
|
||||
{{
|
||||
{nodes_str}
|
||||
{node_str}
|
||||
}}
|
||||
"""
|
||||
)
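# Usage sketch (mirrors test_debug_step in the updated tests): each Debug step duplicates the
# previous step's outputs, so chaining three of them adds three extra graph outputs.
#
#   pipeline = PrePostProcessor(inputs)
#   pipeline.add_post_processing([Debug(1), Debug(1), Debug(1)])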
|
||||
|
|
|
@ -3,3 +3,4 @@
|
|||
|
||||
from .general import *
|
||||
from .vision import *
|
||||
from .nlp import *
|
||||
|
|
|
@ -210,3 +210,45 @@ class Unsqueeze(Step):
|
|||
)
|
||||
|
||||
return unsqueeze_graph
|
||||
|
||||
|
||||
class ArgMax(Step):
|
||||
def __init__(self, name: Optional[str] = None, axis: int = -1, keepdims: int = 0):
|
||||
"""
|
||||
Brief:
|
||||
Same as ArgMax op.
|
||||
Args:
|
||||
name: Optional name of step. Defaults to 'ArgMax'
|
||||
|
||||
"""
|
||||
super().__init__(["data"], ["index"], name)
|
||||
self._axis = axis
|
||||
self._keepdims = keepdims
|
||||
|
||||
def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
|
||||
input_type_str_0, input_shape_str_0 = self._get_input_type_and_shape_strs(graph, 0)
|
||||
input_shape_0 = input_shape_str_0.split(",")
|
||||
|
||||
def build_input_declare():
|
||||
return f"{input_type_str_0}[{input_shape_str_0}] {self.input_names[0]}"
|
||||
|
||||
axis = self._axis + len(input_shape_0) if self._axis < 0 else self._axis
|
||||
if axis >= len(input_shape_0):
|
||||
raise ValueError("axis should be in range [-rank, rank-1].")
|
||||
|
||||
output_shape_str = input_shape_0.copy()
|
||||
output_shape_str[axis] = "1"
|
||||
if self._keepdims == 0:
|
||||
output_shape_str.pop(axis)
|
||||
|
||||
converter_graph = onnx.parser.parse_graph(
|
||||
f"""\
|
||||
classify ({build_input_declare()})
|
||||
=> (int64[{','.join(output_shape_str)}] {self.output_names[0]})
|
||||
{{
|
||||
{self.output_names[0]} = ArgMax<axis = {self._axis}, keepdims={self._keepdims}>({self.input_names[0]})
|
||||
}}
|
||||
"""
|
||||
)
|
||||
|
||||
return converter_graph
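# Usage sketch: the classification path in add_pre_post_processing_to_model appends this step
# so the final model returns a class index instead of raw logits.
#
#   pipeline.add_post_processing([ArgMax()])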
|
||||
|
|
|
@ -0,0 +1,337 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import onnx
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Optional, Union, Dict
|
||||
from ..step import Step
|
||||
|
||||
|
||||
class TokenizerParam(object):
|
||||
def __init__(self, vocab_or_file: Union[Path, dict], **kwargs):
|
||||
self.vocab_or_file = vocab_or_file
|
||||
self.tweaked_bos_id = 1
|
||||
self.strip_accents = 0
|
||||
self.do_lower_case = 0
|
||||
self.is_sentence_pair = 0
|
||||
self.__assigned_with_kwargs(**kwargs)
|
||||
|
||||
def __assigned_with_kwargs(self, **kwargs):
|
||||
for key in self.__dict__.keys():
|
||||
if key in kwargs and kwargs.get(key) is not None:
|
||||
setattr(self, key, kwargs[key])
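# Sketch: only kwargs that are passed (and are not None) override the defaults above, e.g.
#
#   param = TokenizerParam(vocab_or_file="vocab.txt", do_lower_case=1)  # illustrative path
#   # param.do_lower_case == 1; param.tweaked_bos_id keeps its default of 1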
|
||||
|
||||
|
||||
class SentencePieceTokenizer(Step):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_param: TokenizerParam,
|
||||
nbest_size=0,
|
||||
alpha=1.0,
|
||||
reverse=False,
|
||||
add_bos=False,
|
||||
add_eos=False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Brief:
|
||||
SentencePieceTokenizer actually has 6 inputs in its definition, but we only require the user to provide the text input,
and make the others ("nbest_size", "alpha", "add_bos", "add_eos", "reverse") optional.
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer
|
||||
you can create a TokenizerParam object like:
|
||||
tokenizer_param = TokenizerParam(vocab_size=tokenizer.vocab_size,
|
||||
tweaked_bos_id=tokenizer.tweaked_bos_id)
|
||||
|
||||
nbest_size: int, optional (default = 0)
|
||||
alpha: float, optional (default = 1.0)
|
||||
reverse: bool, optional (default = False)
|
||||
add_bos: bool, optional (default = False)
|
||||
add_eos: bool, optional (default = False)
|
||||
Please see more detail explanation in
|
||||
https://www.tensorflow.org/text/api_docs/python/text/SentencepieceTokenizer#args
|
||||
|
||||
name: Optional name of step. Defaults to 'SentencePieceTokenizer'
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
["input_text", "nbest_size", "alpha", "add_bos", "add_eos", "reverse"], ["input_ids", "attention_mask"], name
|
||||
)
|
||||
self._tokenizer_param = tokenizer_param
|
||||
# python bool value (True/False) is not supported in c++, so we use 0/1 to represent bool
|
||||
self._optional_kwargs = dict(
|
||||
nbest_size=nbest_size, alpha=alpha, add_bos=int(add_bos), add_eos=int(add_eos), reverse=int(reverse)
|
||||
)
|
||||
|
||||
def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
|
||||
# input text
|
||||
input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0)
|
||||
input_shape_0 = input_shape_str0.split(",")
|
||||
# Ideally we would support batched input where each entry has a different length and produces its own tokens.
# !!! However, the SentencePieceTokenizer implementation does not support batching; inputs are flattened to 1D
# in the sentence-piece kernel.
|
||||
assert input_type_str0 == "string"
|
||||
|
||||
# We have to apply this hack here because some models tweaked bos_id to 0, while the model file
# still uses 1 as the default value.
# This is only a temporary solution; we will remove it in the future.
|
||||
tweak_bos_id = False
|
||||
if self._tokenizer_param.tweaked_bos_id != 1 and self._optional_kwargs["add_bos"]:
|
||||
self._optional_kwargs["add_bos"] = 0
|
||||
tweak_bos_id = True
|
||||
|
||||
batch_dim = input_shape_0[0] if len(input_shape_0) > 1 else "1"
|
||||
prefix_ = f'step_{self.step_num}'
|
||||
output_shape_str = f"{batch_dim}, {prefix_}__num_ids"
|
||||
|
||||
def build_input_declare():
|
||||
input_base = f"{input_type_str0}[{input_shape_str0}] {self.input_names[0]}"
|
||||
return input_base
|
||||
|
||||
def build_call_para():
|
||||
para_base = ["input_with_batch"]
|
||||
para_base.append("i64_nbest_size")
|
||||
para_base.append("f32_alpha")
|
||||
para_base.append("bool_add_bos")
|
||||
para_base.append("bool_add_eos")
|
||||
para_base.append("bool_reverse")
|
||||
return ",".join(para_base)
|
||||
|
||||
def build_forward_declare():
|
||||
# default values for nbest_size, alpha, add_bos, add_eos, reverse
|
||||
declare_base = [
|
||||
f"i64_nbest_size = Constant <value = int64[1] {{{self._optional_kwargs['nbest_size']}}}> ()",
|
||||
f"f32_alpha = Constant <value = float[1] {{ {self._optional_kwargs['alpha']} }}> ()",
|
||||
f"bool_add_bos = Constant <value = bool[1] {{{self._optional_kwargs['add_bos']}}}> ()",
|
||||
f"bool_add_eos = Constant <value = bool[1] {{{self._optional_kwargs['add_eos']}}}> ()",
|
||||
f"bool_reverse = Constant <value = bool[1] {{{self._optional_kwargs['reverse']}}}> ()",
|
||||
]
|
||||
|
||||
return "\n".join(declare_base)
|
||||
|
||||
# TODO: Camembert and XLMRoberta tokenizers have a different bos_token_id (0) from the default value (1).
# For now, we work around it with the hack above.
|
||||
|
||||
def hack_bos_id():
|
||||
if tweak_bos_id:
|
||||
return f'''
|
||||
k_start = Constant <value = int32[1] {{{self._tokenizer_param.tweaked_bos_id}}}> ()
|
||||
input_ids_concat02 = Concat <axis = 0> (k_start, token)
|
||||
input_ids_bdim = Unsqueeze(input_ids_concat02, i64_0)
|
||||
'''
|
||||
else:
|
||||
return '''
|
||||
input_ids_bdim = Unsqueeze(token, i64_0)
|
||||
'''
|
||||
|
||||
def build_unsqueeze():
|
||||
if len(input_shape_0) == 1:
|
||||
return f"""
|
||||
input_with_batch = Unsqueeze({self.input_names[0]}, i64_0)
|
||||
"""
|
||||
else:
|
||||
return f"""
|
||||
input_with_batch = Identity({self.input_names[0]})
|
||||
"""
|
||||
|
||||
converter_graph = onnx.parser.parse_graph(
|
||||
f"""\
|
||||
SentencePiecetokenizer ({build_input_declare()})
|
||||
=> (int64[{output_shape_str}] {self.output_names[0]},int64[{output_shape_str}] {self.output_names[1]})
|
||||
{{
|
||||
{build_forward_declare()}
|
||||
i64_neg1 = Constant <value = int64[1] {{-1}}> ()
|
||||
i64_0 = Constant <value = int64[1] {{0}}> ()
|
||||
{build_unsqueeze()}
|
||||
token,idx = com.microsoft.extensions.SentencepieceTokenizer ({build_call_para()})
|
||||
{hack_bos_id()}
|
||||
{self.output_names[0]} = Cast <to = 7> (input_ids_bdim)
|
||||
attention_mask_i32=Greater({self.output_names[0]}, i64_neg1)
|
||||
{self.output_names[1]} = Cast <to = 7> (attention_mask_i32)
|
||||
}}
|
||||
"""
|
||||
)
|
||||
|
||||
with open(self._tokenizer_param.vocab_or_file, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
token_model_attr = onnx.helper.make_attribute("model", content)
|
||||
node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == "SentencepieceTokenizer")
|
||||
converter_graph.node[node_idx].attribute.append(token_model_attr)
|
||||
|
||||
return converter_graph
|
||||
|
||||
|
||||
def _vocab_to_dict(vocab_or_file: Union[Dict[str, int], Path, str]):
|
||||
if isinstance(vocab_or_file, (Path, str)):
|
||||
# read from file
|
||||
import json
|
||||
with open(vocab_or_file, "r") as f:
|
||||
vocab = json.load(f)
|
||||
else:
|
||||
vocab = vocab_or_file
|
||||
|
||||
ordered_vocab = OrderedDict(sorted(vocab.items(), key=lambda item: int(item[1])))
|
||||
|
||||
vocab = "\n".join(ordered_vocab.keys())
|
||||
return dict(vocab_file=vocab)
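# Sketch of the conversion: the vocab dict is ordered by token id and flattened into the
# newline-separated string the tokenizer custom op expects, e.g.
#
#   _vocab_to_dict({"[PAD]": 0, "[CLS]": 1, "hello": 2})
#   # -> {"vocab_file": "[PAD]\n[CLS]\nhello"}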
|
||||
|
||||
|
||||
class BertTokenizer(Step):
|
||||
def __init__(self, tokenizer_param: TokenizerParam, name: Optional[str] = None):
|
||||
"""
|
||||
Brief: This step is used to convert the input text into the input_ids, attention_mask, token_type_ids.
|
||||
It supports an input of a single string for classification models, or two strings for QA models.
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer,
|
||||
You can create a TokenizerParam like this:
|
||||
tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, # vocab is dict or file_path
|
||||
strip_accents = True or False (Optional),
|
||||
do_lower_case = True or False (Optional)
|
||||
)
|
||||
|
||||
name: Optional name of step. Defaults to 'BertTokenizer'
|
||||
|
||||
"""
|
||||
super().__init__(["input_text"], ["input_ids", "attention_mask", "token_type_ids"], name)
|
||||
self._tokenizer_param = tokenizer_param
|
||||
|
||||
def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
|
||||
input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0)
|
||||
|
||||
input_shape_0 = input_shape_str0.split(",")
|
||||
prefix_ = f'step_{self.step_num}'
|
||||
# only support batch size 1 until the tokenizer op supports batch size > 1
|
||||
batch_dim = input_shape_0[0] if len(input_shape_0) > 1 else "1"
|
||||
output_shape_str = f"{batch_dim}, _{prefix_}__num_ids"
|
||||
assert input_type_str0 == "string"
|
||||
|
||||
onnx_tokenizer_impl = "HfBertTokenizer" if self._tokenizer_param.is_sentence_pair else "BertTokenizer"
|
||||
|
||||
def build_output_declare():
|
||||
output_base = []
|
||||
for out in self.output_names:
|
||||
output_base.append(f"int64[{output_shape_str}] {out}")
|
||||
|
||||
return ",".join(output_base)
|
||||
|
||||
def get_tokenizer_ret():
|
||||
if onnx_tokenizer_impl == "HfBertTokenizer":
|
||||
return ",".join(self.output_names)
|
||||
# different output orders for BertTokenizer and HfBertTokenizer
|
||||
return "ids,types,mask"
|
||||
|
||||
def build_output_imp():
|
||||
if onnx_tokenizer_impl == "HfBertTokenizer":
|
||||
return ""
|
||||
|
||||
# BertTokenizer has different output dimensions
|
||||
ret_vars = get_tokenizer_ret().split(",")
|
||||
ret_vars[1], ret_vars[2] = ret_vars[2], ret_vars[1]
|
||||
output_str = []
|
||||
|
||||
for idx, out in enumerate(self.output_names):
|
||||
output_str.append(f"{out} = Unsqueeze({ret_vars[idx]}, i64_0)")
|
||||
|
||||
return "\n".join(output_str)
|
||||
|
||||
def build_input_declare():
|
||||
inputs = f"{input_type_str0}[{input_shape_str0}] {self.input_names[0]}"
|
||||
return inputs
|
||||
|
||||
def build_unsqueeze():
|
||||
if len(input_shape_0) == 1:
|
||||
return f"""
|
||||
input_with_batch = Unsqueeze({self.input_names[0]}, i64_0)
|
||||
"""
|
||||
else:
|
||||
return f"""
|
||||
input_with_batch = Identity({self.input_names[0]})
|
||||
"""
|
||||
|
||||
converter_graph = onnx.parser.parse_graph(
|
||||
f"""\
|
||||
{onnx_tokenizer_impl} ({build_input_declare()})
|
||||
=> ({build_output_declare()})
|
||||
{{
|
||||
i64_0 = Constant <value = int64[1] {{0}}> ()
|
||||
{build_unsqueeze()}
|
||||
{get_tokenizer_ret()} = com.microsoft.extensions.{onnx_tokenizer_impl} (input_with_batch)
|
||||
{build_output_imp()}
|
||||
}}
|
||||
"""
|
||||
)
|
||||
|
||||
bert_tokenizer_param = self._tokenizer_param
|
||||
token_model_attr = []
|
||||
|
||||
attrs = _vocab_to_dict(bert_tokenizer_param.vocab_or_file)
|
||||
attrs["strip_accents"] = bert_tokenizer_param.strip_accents
|
||||
attrs["do_lower_case"] = bert_tokenizer_param.do_lower_case
|
||||
|
||||
for attr in attrs:
|
||||
token_model_attr.append(onnx.helper.make_attribute(attr, attrs[attr]))
|
||||
|
||||
node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == onnx_tokenizer_impl)
|
||||
converter_graph.node[node_idx].attribute.extend(token_model_attr)
|
||||
|
||||
return converter_graph
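# Usage sketch (mirrors test_bert_tokenizer in the updated tests; the vocab path is illustrative):
#
#   params = TokenizerParam(vocab_or_file="bert.vocab", do_lower_case=True)
#   step = BertTokenizer(params)
#   # outputs: input_ids, attention_mask, token_type_ids, each shaped [batch, num_ids]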
|
||||
|
||||
|
||||
class BertTokenizerQADecoder(Step):
|
||||
def __init__(self, tokenizer_param: TokenizerParam, name: Optional[str] = None):
|
||||
"""
|
||||
Brief:
|
||||
Decode the input_ids to text
|
||||
Args:
|
||||
tokenizer_param: essential information needed to build a tokenizer.
|
||||
you can create a TokenizerParam object like:
|
||||
tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, #vocab is dict or file_path)
|
||||
name: Optional name of step. Defaults to 'BertTokenizerQADecoder'
|
||||
"""
|
||||
super().__init__(
|
||||
["start_logits", "end_logits", "input_ids"], ["text"], name)
|
||||
self._tokenizer_param = tokenizer_param
|
||||
|
||||
def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
|
||||
def build_input_declare():
|
||||
inputs = []
|
||||
for idx, inp in enumerate(self.input_names):
|
||||
input_type_str_x, input_shape_str_x = self._get_input_type_and_shape_strs(graph, idx)
|
||||
inputs.append(f"{input_type_str_x}[{input_shape_str_x}] {inp}")
|
||||
return ",".join(inputs)
|
||||
|
||||
# A unique name for output shape
|
||||
prefix_ = f'step_{self.step_num}'
|
||||
output_shape_str = f"_{prefix_}_any_len"
|
||||
converter_graph = onnx.parser.parse_graph(
|
||||
f"""\
|
||||
tokenizer_decoder ({build_input_declare()})
|
||||
=> (string[{output_shape_str}] {self.output_names[0]})
|
||||
{{
|
||||
i64_em = Constant <value = int64[0] {{}}> ()
|
||||
i64_1 = Constant <value = int64[1] {{1}}> ()
|
||||
i64_0 = Constant <value = int64[1] {{0}}> ()
|
||||
i64_neg1 = Constant <value = int64[1] {{-1}}> ()
|
||||
|
||||
s_position = ArgMax<axis = -1, keepdims = 0>({self.input_names[0]})
|
||||
e_position = ArgMax<axis = -1, keepdims = 0>({self.input_names[1]})
|
||||
ee_position = Add(e_position,i64_1)
|
||||
u_i64_neg1 = Unsqueeze(i64_neg1, i64_0)
|
||||
slice_ids= Slice({self.input_names[2]}, s_position, ee_position, i64_neg1)
|
||||
{self.output_names[0]} = com.microsoft.extensions.BertTokenizerDecoder (slice_ids, i64_em)
|
||||
}}
|
||||
"""
|
||||
)
|
||||
|
||||
attrs = _vocab_to_dict(self._tokenizer_param.vocab_or_file)
|
||||
token_model_attr = []
|
||||
for attr in attrs:
|
||||
token_model_attr.append(onnx.helper.make_attribute(attr, attrs[attr]))
|
||||
|
||||
node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == "BertTokenizerDecoder")
|
||||
converter_graph.node[node_idx].attribute.extend(token_model_attr)
|
||||
|
||||
return converter_graph
|
|
@ -21,7 +21,8 @@ def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
|
|||
Returns:
|
||||
An onnx.ValueInfoProto that can be used as a new model input.
|
||||
"""
|
||||
tensor_type = onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape)
|
||||
tensor_type = onnx.helper.make_tensor_type_proto(
|
||||
elem_type=data_type, shape=shape)
|
||||
return onnx.helper.make_value_info(name, tensor_type)
|
||||
|
||||
|
||||
|
@ -81,6 +82,32 @@ class IoMapEntry:
|
|||
consumer_idx: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class IOEntryValuePreserver:
|
||||
"""
|
||||
used to allow an output value to have multiple consumers,
|
||||
which is only possible when IoMapEntry is used to create those additional connections.
|
||||
|
||||
Generally, a connection consumes an output and an input, then the output is removed from the graph.
|
||||
This class enables one-to-many connections by making the other consumers share the same output.
|
||||
|
||||
How this class works:
|
||||
1. When the IoMapEntry is created, this class is created at the same time.
2. It records the producer and consumer steps, and the output index of the producer step.
   When the producer step runs, this IOEntryValuePreserver is activated and starts preserving the output.
3. When graphs are merged, this class checks whether the output is still in the graph and, if not, adds it back.
4. When the consumer step runs, this class is deactivated and the output is removed from the preserved list.
|
||||
"""
|
||||
|
||||
producer: Union["Step", str] = None
|
||||
consumer: Union["Step", str] = None
|
||||
# output index from the producer step
|
||||
producer_idx: int = 0
|
||||
is_active: bool = False
|
||||
output: str = None
|
||||
|
||||
|
||||
def sanitize_output_names(graph: onnx.GraphProto):
|
||||
"""
|
||||
Convert any usage of invalid characters like '/' and ';' in value names to '_'
|
||||
|
|
File diff suppressed because one or more lines are too long
Binary file not shown
File diff suppressed because one or more lines are too long
Binary file not shown
|
@ -7,7 +7,7 @@ import io
|
|||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
@ -16,6 +16,9 @@ from distutils.version import LooseVersion
|
|||
# `pip install -e .` from the repo root.
|
||||
from onnxruntime_extensions import get_library_path
|
||||
from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp
|
||||
from onnxruntime_extensions.tools import pre_post_processing as pre_post_processing
|
||||
from onnxruntime_extensions.tools.pre_post_processing.steps import *
|
||||
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
|
||||
|
@ -59,8 +62,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
input_batch = input_tensor.unsqueeze(
|
||||
0).detach().cpu().numpy() # create a mini-batch as expected by the model
|
||||
|
||||
s = ort.InferenceSession(input_model)
|
||||
scores = s.run(None, {'x': np.array(input_batch)})
|
||||
s = ort.InferenceSession(input_model, providers=['CPUExecutionProvider'])
|
||||
scores = s.run(None, {"x": np.array(input_batch)})
|
||||
scores = np.squeeze(scores)
|
||||
|
||||
def softmax(x):
|
||||
|
@ -75,8 +78,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
|
||||
s = ort.InferenceSession(output_model, so)
|
||||
probabilities = s.run(None, {'image': np.array(input_bytes)})[0]
|
||||
s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider'])
|
||||
probabilities = s.run(None, {"image": np.array(input_bytes)})[0]
|
||||
probabilities = np.squeeze(probabilities) # remove batch dim
|
||||
return probabilities
|
||||
|
||||
|
@ -100,6 +103,7 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
# can still use PT pre-processing as it's using PIL for images.
|
||||
# Update the Normalize values to match TF requirements.
|
||||
from torchvision import transforms
|
||||
|
||||
input_image = Image.open(input_image_path)
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
|
@ -108,12 +112,13 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
input_tensor = preprocess(input_image)
|
||||
input_batch = input_tensor.unsqueeze(
|
||||
0).detach().cpu().numpy() # create a mini-batch as expected by the model
|
||||
input_batch = np.transpose(input_batch, (0, 2, 3, 1)) # to NHWC format for TF input
|
||||
# create a mini-batch as expected by the model
|
||||
input_batch = (input_tensor.unsqueeze(0).detach().cpu().numpy())
|
||||
# to NHWC format for TF input
|
||||
input_batch = np.transpose(input_batch, (0, 2, 3, 1))
|
||||
|
||||
s = ort.InferenceSession(input_model)
|
||||
probabilities = s.run(None, {'input': np.array(input_batch)})[0]
|
||||
s = ort.InferenceSession(input_model, providers=['CPUExecutionProvider'])
|
||||
probabilities = s.run(None, {"input": np.array(input_batch)})[0]
|
||||
return np.squeeze(probabilities)
|
||||
|
||||
def new_output():
|
||||
|
@ -123,8 +128,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
|
||||
s = ort.InferenceSession(output_model, so)
|
||||
probabilities = s.run(None, {'image': np.array(input_bytes)})[0]
|
||||
s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider'])
|
||||
probabilities = s.run(None, {"image": np.array(input_bytes)})[0]
|
||||
return np.squeeze(probabilities) # remove batch dim
|
||||
|
||||
orig_results = orig_output()
|
||||
|
@ -158,14 +163,14 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
|
||||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
s = ort.InferenceSession(output_model, so)
|
||||
s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider'])
|
||||
|
||||
result_bytes = s.run(None, {'image': np.array(input_bytes)})[0]
|
||||
result_bytes = s.run(None, {"image": np.array(input_bytes)})[0]
|
||||
|
||||
# convert from png to RGB to remove any png encoding diffs
|
||||
result_img = Image.open(io.BytesIO(result_bytes))
|
||||
result = np.array(result_img.convert('RGB'))
|
||||
expected = np.array(Image.open(expected_output_image_path).convert('RGB'))
|
||||
result = np.array(result_img.convert("RGB"))
|
||||
expected = np.array(Image.open(expected_output_image_path).convert("RGB"))
|
||||
|
||||
# check all pixel values are close. allowing for 0.1% of pixels to differ by 2 at the most.
|
||||
#
|
||||
|
@ -174,8 +179,147 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
|
|||
# whether avx512 is used or not. MacOS seems to be slightly worse though (max of 2)
|
||||
diffs = np.absolute(expected.astype(np.int32) - result.astype(np.int32))
|
||||
total = np.sum(diffs)
|
||||
print(f'Max diff:{diffs.max()} Total diffs:{total}')
|
||||
print(f"Max diff:{diffs.max()} Total diffs:{total}")
|
||||
self.assertTrue(diffs.max() < 3 and total < (result.size / 1000))
|
||||
|
||||
def create_pipeline_and_run_for_tokenizer(self, tokenizer_impl, tokenizer_type,
|
||||
tokenizer_parameters, output_model: Path):
|
||||
import onnx
|
||||
create_named_value = pre_post_processing.utils.create_named_value
|
||||
|
||||
inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]
|
||||
|
||||
pipeline = pre_post_processing.PrePostProcessor(inputs)
|
||||
# ref_output = list(tokenizer(*input_text, return_tensors="np").values())
|
||||
|
||||
pipeline.add_pre_processing([tokenizer_impl])
|
||||
if tokenizer_type == "HfBertTokenizer_with_decoder":
|
||||
pipeline.add_post_processing([
|
||||
(BertTokenizerQADecoder(tokenizer_parameters), [
|
||||
pre_post_processing.utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)])
|
||||
])
|
||||
input_model = onnx.load(os.path.join(test_data_dir, "../bert_qa_decoder_base.onnx"))
|
||||
else:
|
||||
input_model = onnx.ModelProto()
|
||||
input_model.opset_import.extend([onnx.helper.make_operatorsetid("", 16)])
|
||||
new_model = pipeline.run(input_model)
|
||||
onnx.save_model(new_model, output_model)
|
||||
return
|
||||
|
||||
def test_sentencepiece_tokenizer(self):
|
||||
output_model = (Path(test_data_dir) / "../sentencePiece.onnx").resolve()
|
||||
|
||||
input_text = ("This is a test sentence",)
|
||||
ref_output = [np.array([[0, 3293, 83, 10, 3034, 149357, 2]]),
|
||||
np.array([[1, 1, 1, 1, 1, 1, 1]])]
|
||||
# tokenizer = transformers.AutoTokenizer.from_pretrained("xlm-roberta-base")
|
||||
tokenizer_parameters = TokenizerParam(
|
||||
vocab_or_file=os.path.join(test_data_dir, "../sentencepiece.bpe.model"),
|
||||
tweaked_bos_id=0,
|
||||
)
|
||||
tokenizer_impl = SentencePieceTokenizer(tokenizer_parameters, add_eos=True, add_bos=True)
|
||||
self.create_pipeline_and_run_for_tokenizer(
|
||||
tokenizer_impl, "SentecePieceTokenizer", tokenizer_parameters, output_model)
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"])
|
||||
|
||||
result = s.run(None, {s.get_inputs()[0].name: np.array([[*input_text]])})
|
||||
|
||||
# The SentencePieceTokenizer in ORT rounds toward zero, so we need to use atol=1
|
||||
self.assertEqual(np.allclose(result[0], ref_output[0], atol=1), True)
|
||||
self.assertEqual(np.allclose(result[1], ref_output[1]), True)
|
||||
|
||||
def test_bert_tokenizer(self):
|
||||
output_model = (Path(test_data_dir) / "../bert_tokenizer.onnx").resolve()
|
||||
input_text = ("This is a test sentence",)
|
||||
ref_output = [
|
||||
np.array([[2, 236, 118, 16, 1566, 875, 643, 3]]),
|
||||
np.array([[0, 0, 0, 0, 0, 0, 0, 0]]),
|
||||
np.array([[1, 1, 1, 1, 1, 1, 1, 1]]),
|
||||
]
|
||||
# tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert")
|
||||
tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../bert.vocab"),
|
||||
do_lower_case=True)
|
||||
tokenizer_impl = BertTokenizer(tokenizer_parameters)
|
||||
self.create_pipeline_and_run_for_tokenizer(
|
||||
tokenizer_impl, "BertTokenizer", tokenizer_parameters, output_model)
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"])
|
||||
|
||||
result = s.run(None, {s.get_inputs()[0].name: np.array([[*input_text]])})
|
||||
|
||||
self.assertEqual(np.allclose(result[0], ref_output[0]), True)
|
||||
self.assertEqual(np.allclose(result[1], ref_output[2]), True)
|
||||
self.assertEqual(np.allclose(result[2], ref_output[1]), True)
|
||||
|
||||
def test_hfbert_tokenizer(self):
|
||||
output_model = (Path(test_data_dir) / "../hfbert_tokenizer.onnx").resolve()
|
||||
ref_output = ([
|
||||
np.array([[2, 236, 118, 16, 1566, 875, 643, 3, 236, 118, 978, 1566, 875, 643, 3]]),
|
||||
np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]),
|
||||
np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
|
||||
]
|
||||
)
|
||||
# tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert")
|
||||
tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../hfbert.vocab"),
|
||||
do_lower_case=True, is_sentence_pair=True)
|
||||
input_text = ("This is a test sentence", "This is another test sentence")
|
||||
tokenizer_impl = BertTokenizer(tokenizer_parameters)
|
||||
self.create_pipeline_and_run_for_tokenizer(
|
||||
tokenizer_impl, "HfBertTokenizer", tokenizer_parameters, output_model)
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.register_custom_ops_library(get_library_path())
|
||||
s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"])
|
||||
|
||||
result = s.run(None, {s.get_inputs()[0].name: np.array([[input_text[0], input_text[1]]])})
|
||||
|
||||
self.assertEqual(np.allclose(result[0], ref_output[0]), True)
|
||||
self.assertEqual(np.allclose(result[1], ref_output[2]), True)
|
||||
self.assertEqual(np.allclose(result[2], ref_output[1]), True)
|
||||
|
||||
    def test_qatask_with_tokenizer(self):
        output_model = (Path(test_data_dir) / "../hfbert_tokenizer.onnx").resolve()
        ref_output = [np.array(["[CLS]"])]
        # tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert")
        tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../hfbert.vocab"),
                                              do_lower_case=True, is_sentence_pair=True)
        input_text = ("This is a test sentence", "This is another test sentence")
        tokenizer_impl = BertTokenizer(tokenizer_parameters)
        self.create_pipeline_and_run_for_tokenizer(
            tokenizer_impl, "HfBertTokenizer_with_decoder", tokenizer_parameters, output_model)

        so = ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"])

        result = s.run(None, {s.get_inputs()[0].name: np.array([[input_text[0], input_text[1]]])})

        self.assertEqual(result[0][0], ref_output[0][0])

    # Corner Case
    def test_debug_step(self):
        import onnx

        create_named_value = pre_post_processing.utils.create_named_value

        # multiple Debug steps are strung together
        input_model_path = os.path.join(test_data_dir, "pytorch_super_resolution.onnx")
        inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]
        pipeline = pre_post_processing.PrePostProcessor(inputs)
        # each Debug step adds a new model output
        post_processing = [pre_post_processing.Debug(1), pre_post_processing.Debug(1), pre_post_processing.Debug(1)]

        pipeline.add_post_processing(post_processing)
        input_model = onnx.load(input_model_path)
        new_model = pipeline.run(input_model)

        self.assertEqual(len(new_model.graph.output), len(input_model.graph.output) + len(post_processing))


if __name__ == "__main__":
    unittest.main()
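
# Usage note (illustrative; the exact module path depends on where this test file lives in the repo):
# the tokenizer cases above can be run on their own with the standard unittest filter, e.g.
#   python -m unittest <this_test_module> -k tokenizer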
@@ -0,0 +1,190 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import os
import shutil
import re
import tempfile
import functools
from pathlib import Path

from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp
import onnxruntime_extensions

# for tokenizer
import transformers
import numpy as np
import onnxruntime


# avoid loading the model from HuggingFace multiple times; it is time consuming
@functools.lru_cache
def get_tokenizer_and_model_from_huggingface(model_name):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    config = transformers.AutoConfig.from_pretrained(model_name)

    if model_name == "xlm-roberta-base":
        model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
        onnx_config = transformers.models.xlm_roberta.XLMRobertaOnnxConfig(config, "sequence-classification")
        text = ("Hello, my dog is cute",)
    elif model_name == "google/mobilebert-uncased":
        model = transformers.MobileBertForNextSentencePrediction.from_pretrained(model_name)
        onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "default")
        text = ("where is Jim Henson?", "he is at school from where two blocks away")
    elif model_name == "csarron/mobilebert-uncased-squad-v2":
        model = transformers.MobileBertForQuestionAnswering.from_pretrained(model_name)
        onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "question-answering")
        text = ("Who was Jim Henson?", "Jim Henson was a nice puppet")
    elif model_name == "lordtt13/emo-mobilebert":
        model = transformers.MobileBertForSequenceClassification.from_pretrained(model_name)
        onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "sequence-classification")
        text = ("Hello, my dog is cute",)
    else:
        raise ValueError(f"{model_name} is not supported yet.")
    return tokenizer, model, onnx_config, text


def export_backbone(model_name: str, bert_onnx_model: Path):
    """
    Export the backbone ONNX model from HuggingFace, or reuse a cached copy if one already exists.
    The exported model usually has "input_ids", "attention_mask" and "token_type_ids" inputs, and tensor outputs.
    """

    # fix the seed so we can reproduce the results
    transformers.set_seed(42)
    tokenizer, model, onnx_config, text = get_tokenizer_and_model_from_huggingface(model_name)

    if bert_onnx_model and bert_onnx_model.exists():
        print("Using cached ONNX model, skipping re-exporting the backbone model.")
        return tokenizer, bert_onnx_model, onnx_config

    # the temporary directory is removed automatically
    with tempfile.TemporaryDirectory() as tmpdir:
        canonized_name = bert_onnx_model.name
        tmp_model_path = Path(tmpdir + "/" + canonized_name)
        onnx_inputs, onnx_outputs = transformers.onnx.export(tokenizer, model, onnx_config, 16, tmp_model_path)
        shutil.copy(tmp_model_path, bert_onnx_model)
        return tokenizer, bert_onnx_model, onnx_config

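
# For illustration only (not executed): the exported backbone can be driven directly with onnxruntime by
# feeding it the tensors produced by the matching HuggingFace tokenizer. Model name and file path are examples.
#   tok, backbone_path, _ = export_backbone("lordtt13/emo-mobilebert", Path("lordtt13_emo_mobilebert.onnx"))
#   feeds = {name: tensor.numpy() for name, tensor in tok("Hello, my dog is cute", return_tensors="pt").items()}
#   sess = onnxruntime.InferenceSession(str(backbone_path), providers=["CPUExecutionProvider"])
#   logits = sess.run(None, feeds)[0]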


def add_pre_post_processing_to_transformers(model_name: str, input_model_file: Path, output_model_file: Path):
    """Construct the pipeline for an end2end model with pre and post processing.
    The final model can take text as input and produce output in text format for models like Q & A.

    Args:
        model_name (str): Model to export from HuggingFace. Used to infer the tokenizer and the onnx model backbone.
        input_model_file (Path): Where the backbone onnx model is saved/cached. If the file does not exist,
            the backbone is exported from HuggingFace first.
        output_model_file (Path): Where to save the final onnx model.
    """
    tokenizer, bert_onnx_model, onnx_config = export_backbone(model_name, input_model_file)
    if not hasattr(tokenizer, "vocab_file"):
        vocab_file = bert_onnx_model.parent / "vocab.txt"
        import json
        with open(str(vocab_file), 'w') as f:
            f.write(json.dumps(tokenizer.vocab))
    else:
        vocab_file = tokenizer.vocab_file
    tokenizer_type = 'BertTokenizer' if model_name != 'xlm-roberta-base' else 'SentencePieceTokenizer'
    task_type = ('NextSentencePrediction' if model_name == 'google/mobilebert-uncased'
                 else ''.join([i.capitalize() for i in onnx_config.task.split('-')]))
    add_ppp.transformers_and_bert(bert_onnx_model, output_model_file,
                                  vocab_file, tokenizer_type,
                                  task_type,
                                  add_debug_before_postprocessing=True)
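
# Example call (illustrative file names): build the end-to-end QA model for the SQuAD MobileBert checkpoint.
# The backbone is exported from HuggingFace first if "mobilebert_squad.onnx" does not exist yet.
#   add_pre_post_processing_to_transformers(
#       "csarron/mobilebert-uncased-squad-v2",
#       Path("mobilebert_squad.onnx"),
#       Path("mobilebert_squad.with_pre_post_processing.onnx"),
#   )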


def verify_results_for_e2e_model(model_name: str, input_bert_model: Path, output_model_file: Path):
    """Verify the end-to-end model by comparing its debug output with the backbone model's output.

    Args:
        model_name: the HuggingFace model name
        input_bert_model: the backbone onnx model, either exported from HuggingFace or provided by the user
        output_model_file: the finalized onnx model to be verified
    """
    tokenizer, hg_model, _, text = get_tokenizer_and_model_from_huggingface(model_name)
    encoded_input = tokenizer(*text, return_tensors="pt")
    transformers.set_seed(42)

    session_options = onnxruntime.SessionOptions()

    session = onnxruntime.InferenceSession(
        str(input_bert_model.resolve(strict=True)), providers=["CPUExecutionProvider"]
    )
    inputs = {key: value.detach().numpy() for key, value in encoded_input.items()}
    output_name_for_verify = session.get_outputs()[0].name
    ref_outputs = session.run([output_name_for_verify], inputs)

    # load the custom ops library so the tokenizer op can be resolved
    session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())

    session = onnxruntime.InferenceSession(
        str(output_model_file.resolve(strict=True)), session_options, providers=["CPUExecutionProvider"]
    )

    inputs = dict(input_text=np.array([[*text]]))
    real_outputs = session.run([output_name_for_verify + "_debug"], inputs)
    assert np.allclose(
        real_outputs[0], ref_outputs[0], atol=1e-2, rtol=1e-6
    ), f"Results do not match, expected: {ref_outputs[0]}, but got {real_outputs[0]}"

    print("Results match:", real_outputs[0], "\ndiff:", real_outputs[0] - ref_outputs[0])


def main():
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="""Add pre and post processing to a model.

        This tutorial supports updating:
        - MobileBert with different tasks
        - XLM-Roberta with the classification task

        This tutorial provides an example of how to add pre/post processing to a transformer model.
        It can add a tokenizer (SentencePieceTokenizer/BertTokenizer/HfBertTokenizer) for pre-processing,
        and a classifier/decoder for post-processing.

        The model is exported from HuggingFace by default if an existing onnx model is not provided.
        NOTE: if you provide an onnx model, make sure it matches the chosen model_type, as the corresponding
        HuggingFace tokenizer is used for the pre-processing.
        """,
    )

    parser.add_argument(
        "-t",
        "--model_type",
        type=str,
        required=True,
        choices=[
            "xlm-roberta-base",
            "google/mobilebert-uncased",
            "csarron/mobilebert-uncased-squad-v2",
            "lordtt13/emo-mobilebert",
        ],
        help="Model type.",
    )

    parser.add_argument(
        "model_path",
        type=Path,
        help="""The path to an existing ONNX model, or a directory to save a model exported from HuggingFace in.
        The model will be updated to add pre/post processing and saved in the same location with the suffix
        '.with_pre_post_processing.onnx'.""",
    )

    args = parser.parse_args()

    model_path = args.model_path.resolve(strict=True)
    canonized_name = re.sub(r"[^a-zA-Z0-9]", "_", args.model_type) + ".onnx"

    if model_path.is_dir():
        model_path = model_path / canonized_name

    new_model_path = model_path.with_suffix(".with_pre_post_processing.onnx")

    add_pre_post_processing_to_transformers(args.model_type, model_path, new_model_path)
    verify_results_for_e2e_model(args.model_type, model_path, new_model_path)
    return new_model_path


if __name__ == "__main__":
    main()
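
# Example invocation (illustrative; run the file this code lives in, and assume ./models already exists):
#   python <this_script>.py --model_type lordtt13/emo-mobilebert ./models
# With a directory argument the backbone is exported to ./models/lordtt13_emo_mobilebert.onnx and the
# final model is written alongside it as lordtt13_emo_mobilebert.with_pre_post_processing.onnx.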