diff --git a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
index e426a6e0..f4dfaf07 100644
--- a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
+++ b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
@@ -7,7 +7,7 @@
 import onnx
 import os
 from pathlib import Path
-
+from typing import Union
 # NOTE: If you're working on this script install onnxruntime_extensions using `pip install -e .` from the repo root
 # and run with `python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model`
 # Running directly will result in an error from a relative import.
@@ -162,21 +162,110 @@ def superresolution(model_file: Path, output_file: Path, output_format: str, onn
     onnx.save_model(new_model, str(output_file.resolve()))
 
 
+class NLPTaskType(enum.Enum):
+    TokenClassification = enum.auto()
+    QuestionAnswering = enum.auto()
+    SequenceClassification = enum.auto()
+    NextSentencePrediction = enum.auto()
+
+
+class TokenizerType(enum.Enum):
+    BertTokenizer = enum.auto()
+    SentencePieceTokenizer = enum.auto()
+
+
+def transformers_and_bert(
+    input_model_file: Path,
+    output_model_file: Path,
+    vocab_file: Path,
+    tokenizer_type: Union[TokenizerType, str],
+    task_type: Union[NLPTaskType, str],
+    onnx_opset: int = 16,
+    add_debug_before_postprocessing=False,
+):
+    """Construct the pipeline for an end-to-end model with pre and post processing. The final model can take text
+    as input and produce its result as text, e.g. for QA models.
+
+    Args:
+        input_model_file (Path): the model file to be updated.
+        output_model_file (Path): where to save the final onnx model.
+        vocab_file (Path): the vocab file for the tokenizer.
+        tokenizer_type (Union[TokenizerType, str]): the tokenizer to add, BertTokenizer or SentencePieceTokenizer.
+        task_type (Union[NLPTaskType, str]): the task type of the model.
+        onnx_opset (int, optional): the opset version to use. Defaults to 16.
+        add_debug_before_postprocessing (bool, optional): whether to add a debug step before post processing.
+            Defaults to False.
+    """
+    if isinstance(task_type, str):
+        task_type = NLPTaskType[task_type]
+    if isinstance(tokenizer_type, str):
+        tokenizer_type = TokenizerType[tokenizer_type]
+
+    onnx_model = onnx.load(str(input_model_file.resolve(strict=True)))
+    # hardcode batch size to 1
+    inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]
+
+    pipeline = PrePostProcessor(inputs, onnx_opset)
+    tokenizer_args = TokenizerParam(
+        vocab_or_file=vocab_file,
+        do_lower_case=True,
+        tweaked_bos_id=0,
+        is_sentence_pair=True if task_type in [NLPTaskType.QuestionAnswering,
+                                               NLPTaskType.NextSentencePrediction] else False,
+    )
+
+    preprocessing = [
+        SentencePieceTokenizer(tokenizer_args)
+        if tokenizer_type == TokenizerType.SentencePieceTokenizer else BertTokenizer(tokenizer_args),
+        # uncomment this line to debug
+        # Debug(2),
+    ]
+
+    # For verifying results without post-processing
+    postprocessing = [Debug()] if add_debug_before_postprocessing else []
+    if task_type == NLPTaskType.QuestionAnswering:
+        postprocessing.append((BertTokenizerQADecoder(tokenizer_args), [
+            # input_ids
+            utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]))
+    elif task_type == NLPTaskType.SequenceClassification:
+        postprocessing.append(ArgMax())
+    # the other tasks don't need post-processing, or we don't support them yet.
+
+    pipeline.add_pre_processing(preprocessing)
+    pipeline.add_post_processing(postprocessing)
+
+    new_model = pipeline.run(onnx_model)
+    onnx.save_model(new_model, str(output_model_file.resolve()))
+
+
 def main():
     parser = argparse.ArgumentParser(
         os.path.basename(__file__),
         description="""Add pre and post processing to a model.
 
         Currently supports updating:
-          - super resolution with YCbCr input
-          - imagenet trained mobilenet
+        Vision models:
+          - super resolution with YCbCr input
+          - imagenet trained mobilenet
+        NLP models:
+
+          - MobileBert with different tasks
+          - XLM-Roberta with classification task
 
-        To customize, the logic in the `mobilenet` and `superresolution` functions can be used as a guide.
+        For Vision models:
+        To customize, the logic in the `mobilenet` and `superresolution` functions can be used as a guide.
         Create a pipeline and add the required pre/post processing 'Steps' in the order required. Configure
-        individual steps as needed.
+        individual steps as needed.
+
+        For NLP models:
+        `transformers_and_bert` can be used for MobileBert QuestionAnswering/Classification tasks,
+        or serve as a guide for how to add pre/post processing to a transformer model.
+        Usually pre-processing includes adding a tokenizer. Post-processing includes conversion of output_ids to text.
+
+        You might need to pass the tokenizer model file (bert vocab file or SentencePieceTokenizer model)
+        and task_type to the function.
 
-        The updated model will be written in the same location as the original model, with '.onnx' updated to
-        '.with_pre_post_processing.onnx'
+        The updated model will be written in the same location as the original model,
+        with '.onnx' updated to '.with_pre_post_processing.onnx'
         """,
     )
@@ -185,7 +274,11 @@ def main():
         "--model_type",
         type=str,
         required=True,
-        choices=["superresolution", "mobilenet"],
+        choices=[
+            "superresolution",
+            "mobilenet",
+            "transformers",
+        ],
         help="Model type.",
     )
@@ -212,9 +305,35 @@ def main():
         help="Image output format for superresolution model to produce.",
     )
 
+    parser.add_argument(
+        "--nlp_task_type",
+        type=str,
+        choices=["QuestionAnswering",
+                 "SequenceClassification",
+                 "NextSentencePrediction"],
+        required=False,
+        help="The downstream task for the NLP model.",
+    )
+
+    parser.add_argument(
+        "--vocab_file",
+        type=Path,
+        required=False,
+        help="Tokenizer model file for BertTokenizer or SentencePieceTokenizer.",
+    )
+
+    parser.add_argument(
+        "--tokenizer_type",
+        type=str,
+        choices=["BertTokenizer",
+                 "SentencePieceTokenizer"],
+        required=False,
+        help="Tokenizer type to add: BertTokenizer or SentencePieceTokenizer.",
+    )
+
     parser.add_argument(
         "--opset",
         type=int,
         required=False,
         default=16,
-        help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing."
+        help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing.",
     )
 
     parser.add_argument("model", type=Path, help="Provide path to ONNX model to update.")
@@ -227,8 +346,13 @@ def main():
     if args.model_type == "mobilenet":
         source = ModelSource.PYTORCH if args.model_source == "pytorch" else ModelSource.TENSORFLOW
         mobilenet(model_path, new_model_path, source, args.opset)
+    elif args.model_type == "superresolution":
+        superresolution(model_path, new_model_path,
+                        args.output_format, args.opset)
     else:
-        superresolution(model_path, new_model_path, args.output_format, args.opset)
+        if args.vocab_file is None or args.nlp_task_type is None or args.tokenizer_type is None:
+            parser.error("Please provide --vocab_file, --tokenizer_type and --nlp_task_type.")
+        transformers_and_bert(model_path, new_model_path, args.vocab_file, args.tokenizer_type, args.nlp_task_type)
 
 
 if __name__ == "__main__":
diff --git a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/step.md b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/step.md
index 580a8828..3651680f 100644
--- a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/step.md
+++ b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/step.md
@@ -8,10 +8,9 @@ Classes
 :   Step that can be arbitrarily inserted in the pre or post processing pipeline.
     It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.
 
-    NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
-    may or may not have '_debug' as a suffix.
-    TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
-    to rename so it's more consistent.
+    The output will be duplicated into two outputs, one will be renamed with a suffix "_next",
+    another will be renamed with a suffix "_debug". The "_next" outputs will feed into the next step,
+    the "_debug" outputs will become graph outputs.
 
     Initialize Debug step
     Args:
@@ -35,11 +34,15 @@ Classes
 
     ### Descendants
 
     * pre_post_processing.step.Debug
+    * pre_post_processing.steps.general.ArgMax
     * pre_post_processing.steps.general.ReverseAxis
     * pre_post_processing.steps.general.Softmax
     * pre_post_processing.steps.general.Squeeze
     * pre_post_processing.steps.general.Transpose
     * pre_post_processing.steps.general.Unsqueeze
+    * pre_post_processing.steps.nlp.BertTokenizer
+    * pre_post_processing.steps.nlp.BertTokenizerQADecoder
+    * pre_post_processing.steps.nlp.SentencePieceTokenizer
     * pre_post_processing.steps.vision.CenterCrop
     * pre_post_processing.steps.vision.ConvertBGRToImage
     * pre_post_processing.steps.vision.ConvertImageToBGR
@@ -57,9 +60,12 @@ Classes
 
     ### Methods
 
-    `apply(self, graph: onnx.onnx_ml_pb2.GraphProto, checker_context: onnx.onnx_cpp2py_export.checker.CheckerContext)`
+    `apply(self, graph: onnx.onnx_ml_pb2.GraphProto, checker_context: onnx.onnx_cpp2py_export.checker.CheckerContext, graph_outputs_to_maintain: List[str])`
    :   Create a graph for this step that can be appended to the provided graph.
        The PrePostProcessor will handle merging the two.
+        graph_outputs_to_maintain: List of output names to maintain in the graph.
+        Some outputs might be consumed during graph merging, so we have to explicitly maintain them for subsequent steps.
+        These outputs are generated by IOEntryValuePreserver.
 
     `connect(self, entry: pre_post_processing.utils.IoMapEntry)`
    :   Connect the value name from a previous step to an input of this step so they match.
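For reference, a minimal usage sketch of the new "transformers" path added above. This is not part of the diff: the file names are hypothetical, and the import path follows the NOTE at the top of add_pre_post_processing_to_model.py. The CLI equivalent is `python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model --model_type transformers --tokenizer_type BertTokenizer --nlp_task_type QuestionAnswering --vocab_file bert.vocab mobilebert_qa.onnx`.

    from pathlib import Path

    from onnxruntime_extensions.tools.add_pre_post_processing_to_model import (
        NLPTaskType,
        TokenizerType,
        transformers_and_bert,
    )

    # hypothetical MobileBert QA model exported to ONNX, plus its BERT vocab (a token -> id JSON map)
    model = Path("mobilebert_qa.onnx")
    updated_model = Path("mobilebert_qa.with_pre_post_processing.onnx")

    transformers_and_bert(
        input_model_file=model,
        output_model_file=updated_model,
        vocab_file=Path("bert.vocab"),
        tokenizer_type=TokenizerType.BertTokenizer,   # or the string "BertTokenizer"
        task_type=NLPTaskType.QuestionAnswering,      # or the string "QuestionAnswering"
    )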
diff --git a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/general.md b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/general.md index 2e61ced6..d251f46d 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/general.md +++ b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/general.md @@ -4,6 +4,18 @@ Module pre_post_processing.steps.general Classes ------- +`ArgMax(name: Optional[str] = None, axis: int = -1, keepdims: int = 0)` +: Base class for a pre or post processing step. + + Brief: + Same as ArgMax op. + Args: + name: Optional name of step. Defaults to 'ArgMax' + + ### Ancestors (in MRO) + + * pre_post_processing.step.Step + `ReverseAxis(axis: int = -1, dim_value: int = -1, name: Optional[str] = None)` : Reverses the data in an axis by splitting and concatenating in reverse order. e.g. convert RGB ordered data to BGR. diff --git a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/index.md b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/index.md index 7524c03a..e0dc0866 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/index.md +++ b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/index.md @@ -4,4 +4,5 @@ Module pre_post_processing.steps Sub-modules ----------- * pre_post_processing.steps.general +* pre_post_processing.steps.nlp * pre_post_processing.steps.vision \ No newline at end of file diff --git a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/nlp.md b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/nlp.md new file mode 100644 index 00000000..21c1f6f7 --- /dev/null +++ b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/steps/nlp.md @@ -0,0 +1,68 @@ +Module pre_post_processing.steps.nlp +==================================== + +Classes +------- + +`BertTokenizer(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, name: Optional[str] = None)` +: Base class for a pre or post processing step. + + Brief: This step is used to convert the input text into the input_ids, attention_mask, token_type_ids. + It supports an input of a single string for classification models, or two strings for QA models. + Args: + tokenizer_param: some essential infos to build a tokenizer, + You can create a TokenizerParam like this: + tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, # vocab is dict or file_path + strip_accents = True or False (Optional), + do_lower_case = True or False (Optional) + ) + + name: Optional name of step. Defaults to 'BertTokenizer' + + ### Ancestors (in MRO) + + * pre_post_processing.step.Step + +`BertTokenizerQADecoder(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, name: Optional[str] = None)` +: Base class for a pre or post processing step. + + Brief: + Decode the input_ids to text + Args: + tokenizer_param: some essential info to build a tokenizer. + you can create a TokenizerParam object like: + tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, #vocab is dict or file_path) + name: Optional name of step. 
Defaults to 'BertTokenizerQADecoder' + + ### Ancestors (in MRO) + + * pre_post_processing.step.Step + +`SentencePieceTokenizer(tokenizer_param: pre_post_processing.steps.nlp.TokenizerParam, nbest_size=0, alpha=1.0, reverse=False, add_bos=False, add_eos=False, name: Optional[str] = None)` +: Base class for a pre or post processing step. + + Brief: + SentencePieceTokenizer has actually 6 inputs in definition, but we allow user to provide only text input, + and make the others, "nbest_size", "alpha", "add_bos", "add_eos", "reverse" optional. + Args: + tokenizer_param: some essential infos to build a tokenizer + you can create a TokenizerParam object like: + tokenizer_param = TokenizerParam(vocab_size=tokenizer.vocab_size, + tweaked_bos_id=tokenizer.tweaked_bos_id) + + nbest_size: int, optional (default = 0) + alpha: float, optional (default = 1.0) + reverse: bool, optional (default = False) + add_bos: bool, optional (default = False) + add_eos: bool, optional (default = False) + Please see more detail explanation in + https://www.tensorflow.org/text/api_docs/python/text/SentencepieceTokenizer#args + + name: Optional name of step. Defaults to 'SentencePieceTokenizer' + + ### Ancestors (in MRO) + + * pre_post_processing.step.Step + +`TokenizerParam(vocab_or_file: Union[pathlib.Path, dict], **kwargs)` +: \ No newline at end of file diff --git a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/utils.md b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/utils.md index 672d6c66..57359102 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/utils.md +++ b/onnxruntime_extensions/tools/pre_post_processing/docs/pre_post_processing/utils.md @@ -46,6 +46,38 @@ Functions Classes ------- +`IOEntryValuePreserver(producer: Union[ForwardRef('Step'), str] = None, consumer: Union[ForwardRef('Step'), str] = None, producer_idx: int = 0, is_active: bool = False, output: str = None)` +: used to allow an output value to have multiple consumers, + which is only possible when IoMapEntry is used to create those additional connections. + + Generally, a connection consumes an output and an input, then the output is removed from the graph. + This class enabled one-to-many connections by making the other consumers share the same output. + + How this class works: + 1. when the IoMapEntry is created, this class will be created simultaneously. + 2. It records the producer and consumer steps, and the output index of the producer step. + when producer step is running, this IOEntryValuePreserver will be activated and start to preserve the output. + 3. when graph merge happens, this class will check if the output is still in the graph, if not, + it will add the output + 4. when consumer step is running, this class will be deactivated and remove output from preserved_list. + + ### Class variables + + `consumer: Union[Step, str]` + : + + `is_active: bool` + : + + `output: str` + : + + `producer: Union[Step, str]` + : + + `producer_idx: int` + : + `IoMapEntry(producer: Union[ForwardRef('Step'), str] = None, producer_idx: int = 0, consumer_idx: int = 0)` : Entry to map the output index from a producer step to the input index of a consumer step. 
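To make the IoMapEntry / IOEntryValuePreserver relationship concrete, a hedged sketch follows (not part of the diff; import locations are assumed from the file paths in this PR, and the vocab path and model are hypothetical). It mirrors the QA wiring used in transformers_and_bert: the tokenizer's first output feeds both the model and the decoder, so the extra consumer is declared via an IoMapEntry and the value is kept alive by an IOEntryValuePreserver.

    import onnx
    from pathlib import Path
    from onnxruntime_extensions.tools.pre_post_processing.pre_post_processor import PrePostProcessor
    from onnxruntime_extensions.tools.pre_post_processing.steps import (
        BertTokenizer, BertTokenizerQADecoder, TokenizerParam)
    from onnxruntime_extensions.tools.pre_post_processing.utils import IoMapEntry, create_named_value

    tokenizer_args = TokenizerParam(vocab_or_file=Path("bert.vocab"), is_sentence_pair=1)  # hypothetical vocab path
    inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]
    pipeline = PrePostProcessor(inputs, 16)

    pipeline.add_pre_processing([BertTokenizer(tokenizer_args)])
    pipeline.add_post_processing([
        # consumer_idx=2 is the decoder's "input_ids" input; producer_idx=0 is the tokenizer's
        # "input_ids" output. That output is also consumed by the model itself, so the
        # PrePostProcessor records an IOEntryValuePreserver for this IoMapEntry and keeps the
        # value alive as a graph output across merges until BertTokenizerQADecoder consumes it.
        (BertTokenizerQADecoder(tokenizer_args),
         [IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]),
    ])
    # new_model = pipeline.run(onnx.load("mobilebert_qa.onnx"))  # hypothetical QA model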
diff --git a/onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py b/onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py index 994899d5..d4960422 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +++ b/onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py @@ -8,6 +8,7 @@ from typing import List, Tuple, Union from .utils import ( IoMapEntry, + IOEntryValuePreserver, create_custom_op_checker_context, sanitize_output_names, TENSOR_TYPE_TO_ONNX_TYPE, @@ -56,6 +57,12 @@ class PrePostProcessor: self._post_processing_joins = None # type: Union[None,List[Tuple[Union[Step, str], int, str]]] self._inputs = inputs if inputs else [] + + # preserve outputs from IOMapEntry, avoid it's consumed by the Follow-up steps. + # we now can support a output value has more than one consumers with IOEntryValuePreserver. + # IOEntryValuePreserver will preserve the output value and add it to the graph output + # until consumer step is done. + self.outputs_preserver = [] # type: List[IOEntryValuePreserver] def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]): """ @@ -140,12 +147,31 @@ class PrePostProcessor: n.name = prefix + str(idx) idx += 1 + def preserved_apply(processor: Step, *args): + # Trying to activate the IOEntryValuePreserver and preserve outputs. + # and deactivate the outputs when the current graph consumed them + + for preserver in self.outputs_preserver: + if preserver.consumer == processor: + preserver.is_active = False + + # IOEntryValuePreserver, preserve those outputs which has multiple consumers. + # we explicitly add the output to the graph output. + graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active] + graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain) + + for preserver in self.outputs_preserver: + if preserver.producer == processor: + preserver.is_active = True + preserver.output = processor.output_names[preserver.producer_idx] + return graph_for_step + def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]): for connection in connections: assert connection.producer self._add_connection(processor, connection) - return processor.apply(graph, self._custom_op_checker_context) + return preserved_apply(processor, graph, self._custom_op_checker_context) # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them if self.post_processors: @@ -173,11 +199,22 @@ class PrePostProcessor: self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)] # map the pre-processing outputs to graph inputs + # we may need a natty way to get possible outputs after merge_graphs + step_graph_outputs = [o.name for o in pre_process_graph.output] io_map = [] # type: List[Tuple[str, str]] for step, step_idx, graph_input in self._pre_processing_joins: io_map.append((step.output_names[step_idx], graph_input)) + step_graph_outputs.remove((step.output_names[step_idx])) - graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map) + # add outputs from previous IoMapEntry producers to maintain them as graph outputs + # until consumed by the final Step that requires them. 
+ step_graph_outputs += [ + o.name for o in graph.output if o.name not in step_graph_outputs] + external_outputs = [ + i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs] + if external_outputs: + step_graph_outputs.extend(external_outputs) + graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs) # add post-processing if self.post_processors: @@ -274,6 +311,7 @@ class PrePostProcessor: producer = self.__producer_from_step_or_str(entry.producer) # throws if not found io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx) + self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx)) processors.append(step) processor_connections.append([entry for entry in io_map_entries if entry is not None]) diff --git a/onnxruntime_extensions/tools/pre_post_processing/step.py b/onnxruntime_extensions/tools/pre_post_processing/step.py index f09557c2..26017c49 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/step.py +++ b/onnxruntime_extensions/tools/pre_post_processing/step.py @@ -48,10 +48,18 @@ class Step(object): self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx] - def apply(self, graph: onnx.GraphProto, checker_context: onnx.checker.C.CheckerContext): + def apply(self, graph: onnx.GraphProto, + checker_context: onnx.checker.C.CheckerContext, + graph_outputs_to_maintain: List[str]): """ Create a graph for this step that can be appended to the provided graph. The PrePostProcessor will handle merging the two. + + Args: + graph_outputs_to_maintain: List of output names to maintain in the graph by additional effort. + For outputs having multiple consumers, these outputs will be consumed by default and prevent + connection from the subsequent steps. + This outputs is generated by IOEntryValuePreserver. """ onnx_opset = checker_context.opset_imports[""] @@ -60,7 +68,7 @@ class Step(object): # prefix the graph for this step to guarantee no clashes of value names with the existing graph onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True) - result = self.__merge(graph, graph_for_step) + result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain) # update self.output_names to the prefixed names so that when we connect later Steps the values match new_outputs = [self._prefix + o for o in self.output_names] @@ -86,8 +94,10 @@ class Step(object): """ pass - def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto): + def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto, + graph_outputs_to_maintain: Optional[List[str]] = None): # We prefixed all the value names in `second`, so allow for that when connecting the two graphs + first_output = [o.name for o in first.output] io_map = [] for o in first.output: # apply the same prefix to the output from the previous step to match the prefixed graph from this step @@ -95,24 +105,13 @@ class Step(object): for i in second.input: if i.name == prefixed_output: io_map.append((o.name, i.name)) - - outputs_to_preserve = None - - # special handling of Debug class. - if isinstance(self, Debug): - # preserve outputs of the first graph so they're available downstream. otherwise they are consumed by - # the Debug node and disappear during the ONNX graph_merge as it considers consumed values to be - # internal - which is entirely reasonable when merging graphs. 
- # the issue we have is that we don't know what future steps might want things to remain as outputs. - # the current approach is to insert a Debug step which simply duplicates the values so that they are - # guaranteed not be consumed (only one of the two copies will be used). - # doesn't change the number of outputs from the previous step, so it can be transparently inserted in the - # pre/post processing pipeline. - # need to also list the second graph's outputs when manually specifying outputs. - outputs_to_preserve = [o.name for o in first.output] + [o.name for o in second.output] - + first_output.remove(o.name) + + graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output] + graph_outputs += [o for o in graph_outputs_to_maintain if o not in graph_outputs] + # merge with existing graph - merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=outputs_to_preserve) + merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs) return merged_graph @@ -154,16 +153,14 @@ class Step(object): return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape) -# special case. we include the helper Debug step here as logic in the base class is conditional on it. class Debug(Step): """ Step that can be arbitrarily inserted in the pre or post processing pipeline. It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged. - NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it - may or may not have '_debug' as a suffix. - TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node - to rename so it's more consistent. + The output will be duplicated into two outputs, one will be renamed with a suffix "_next", + another will be renamed with a suffix "_debug". The "_next" outputs will feed into the next step, + the "_debug" outputs will become graph outputs. 
""" def __init__(self, num_inputs: int = 1, name: Optional[str] = None): @@ -180,32 +177,43 @@ class Debug(Step): super().__init__(input_names, output_names, name) def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int): - input_str = "" - output_str = "" - output_debug_str = "" - nodes_str = "" + if self._num_inputs > len(graph.output): + raise ValueError( + f"Debug step requested {self._num_inputs} inputs, but graph only has {len(graph.output)} outputs.") + debug_offset = len(self.input_names) # update output names so we preserve info from the latest input names - self.output_names = [f"{name}_debug" for name in self.input_names] + self.output_names = [f"{name}_next" for name in self.input_names] + self.output_names += [f"{name}_debug" for name in self.input_names] + input_str_list = [] + output_str_list = [] + nodes_str_list = [] for i in range(0, self._num_inputs): - input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i) - if i > 0: - input_str += ", " - output_str += ", " - output_debug_str += ", " - nodes_str += "\n" + input_type_str, input_shape_str = self._get_input_type_and_shape_strs( + graph, i) - input_str += f"{input_type_str}[{input_shape_str}] {self.input_names[i]}" - output_str += f"{input_type_str}[{input_shape_str}] {self.output_names[i]}" - nodes_str += f"{self.output_names[i]} = Identity({self.input_names[i]})\n" + input_str_list.append( + f"{input_type_str}[{input_shape_str}] {self.input_names[i]}") + output_str_list.append( + f"{input_type_str}[{input_shape_str}] {self.output_names[i]}") + output_str_list.append( + f"{input_type_str}[{input_shape_str}] {self.output_names[debug_offset+i]}") + + nodes_str_list.append( + f"{self.output_names[i]} = Identity({self.input_names[i]})\n") + nodes_str_list.append( + f"{self.output_names[debug_offset+i]} = Identity({self.input_names[i]})\n") + + # f-string can't have back-slash + node_str = '\n'.join(nodes_str_list) debug_graph = onnx.parser.parse_graph( f"""\ - debug ({input_str}) - => ({output_str}) + debug ({','.join(input_str_list)}) + => ({','.join(output_str_list)}) {{ - {nodes_str} + {node_str} }} """ ) diff --git a/onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py b/onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py index dc7ec5e8..cd557903 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +++ b/onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py @@ -3,3 +3,4 @@ from .general import * from .vision import * +from .nlp import * diff --git a/onnxruntime_extensions/tools/pre_post_processing/steps/general.py b/onnxruntime_extensions/tools/pre_post_processing/steps/general.py index db6f217f..b9e35e8a 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/steps/general.py +++ b/onnxruntime_extensions/tools/pre_post_processing/steps/general.py @@ -210,3 +210,45 @@ class Unsqueeze(Step): ) return unsqueeze_graph + + +class ArgMax(Step): + def __init__(self, name: Optional[str] = None, axis: int = -1, keepdims: int = 0): + """ + Brief: + Same as ArgMax op. + Args: + name: Optional name of step. 
Defaults to 'ArgMax' + + """ + super().__init__(["data"], ["index"], name) + self._axis = axis + self._keepdims = keepdims + + def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int): + input_type_str_0, input_shape_str_0 = self._get_input_type_and_shape_strs(graph, 0) + input_shape_0 = input_shape_str_0.split(",") + + def build_input_declare(): + return f"{input_type_str_0}[{input_shape_str_0}] {self.input_names[0]}" + + axis = self._axis + len(input_shape_0) if self._axis < 0 else self._axis + if axis >= len(input_shape_0): + raise ValueError("axis should be in range [-rank, rank-1].") + + output_shape_str = input_shape_0.copy() + output_shape_str[axis] = "1" + if self._keepdims == 0: + output_shape_str.pop(axis) + + converter_graph = onnx.parser.parse_graph( + f"""\ + classify ({build_input_declare()}) + => (int64[{','.join(output_shape_str)}] {self.output_names[0]}) + {{ + {self.output_names[0]} = ArgMax({self.input_names[0]}) + }} + """ + ) + + return converter_graph diff --git a/onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py b/onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py new file mode 100644 index 00000000..31fc906e --- /dev/null +++ b/onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py @@ -0,0 +1,337 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +from collections import OrderedDict +from pathlib import Path + +from typing import Optional, Union, Dict +from ..step import Step + + +class TokenizerParam(object): + def __init__(self, vocab_or_file: Union[Path, dict], **kwargs): + self.vocab_or_file = vocab_or_file + self.tweaked_bos_id = 1 + self.strip_accents = 0 + self.do_lower_case = 0 + self.is_sentence_pair = 0 + self.__assigned_with_kwargs(**kwargs) + + def __assigned_with_kwargs(self, **kwargs): + for key in self.__dict__.keys(): + if key in kwargs and kwargs.get(key) is not None: + setattr(self, key, kwargs[key]) + + +class SentencePieceTokenizer(Step): + def __init__( + self, + tokenizer_param: TokenizerParam, + nbest_size=0, + alpha=1.0, + reverse=False, + add_bos=False, + add_eos=False, + name: Optional[str] = None, + ): + """ + Brief: + SentencePieceTokenizer has actually 6 inputs in definition, but we allow user to provide only text input, + and make the others, "nbest_size", "alpha", "add_bos", "add_eos", "reverse" optional. + Args: + tokenizer_param: some essential infos to build a tokenizer + you can create a TokenizerParam object like: + tokenizer_param = TokenizerParam(vocab_size=tokenizer.vocab_size, + tweaked_bos_id=tokenizer.tweaked_bos_id) + + nbest_size: int, optional (default = 0) + alpha: float, optional (default = 1.0) + reverse: bool, optional (default = False) + add_bos: bool, optional (default = False) + add_eos: bool, optional (default = False) + Please see more detail explanation in + https://www.tensorflow.org/text/api_docs/python/text/SentencepieceTokenizer#args + + name: Optional name of step. 
Defaults to 'SentencePieceTokenizer' + + """ + super().__init__( + ["input_text", "nbest_size", "alpha", "add_bos", "add_eos", "reverse"], ["input_ids", "attention_mask"], name + ) + self._tokenizer_param = tokenizer_param + # python bool value (True/False) is not supported in c++, so we use 0/1 to represent bool + self._optional_kwargs = dict( + nbest_size=nbest_size, alpha=alpha, add_bos=int(add_bos), add_eos=int(add_eos), reverse=int(reverse) + ) + + def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int): + # input text + input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0) + input_shape_0 = input_shape_str0.split(",") + # ideally, we should support batch input, each batch has different length and output a token + # !!! But, the implementation of SentencePieceTokenizer is not batch supported, inputs will be flatten to 1D + # in the sentence-piece kernel + assert input_type_str0 == "string" + + # we have to do this hack here, because some models tweaked bos_id to 0, but we have still 1 + # as default value in model file. + # it is only a temporary solution, we will remove it in the future. + tweak_bos_id = False + if self._tokenizer_param.tweaked_bos_id != 1 and self._optional_kwargs["add_bos"]: + self._optional_kwargs["add_bos"] = 0 + tweak_bos_id = True + + batch_dim = input_shape_0[0] if len(input_shape_0) > 1 else "1" + prefix_ = f'step_{self.step_num}' + output_shape_str = f"{batch_dim}, {prefix_}__num_ids" + + def build_input_declare(): + input_base = f"{input_type_str0}[{input_shape_str0}] {self.input_names[0]}" + return input_base + + def build_call_para(): + para_base = ["input_with_batch"] + para_base.append("i64_nbest_size") + para_base.append("f32_alpha") + para_base.append("bool_add_bos") + para_base.append("bool_add_eos") + para_base.append("bool_reverse") + return ",".join(para_base) + + def build_forward_declare(): + # default values for nbest_size, alpha, add_bos, add_eos, reverse + declare_base = [ + f"i64_nbest_size = Constant ()", + f"f32_alpha = Constant ()", + f"bool_add_bos = Constant ()", + f"bool_add_eos = Constant ()", + f"bool_reverse = Constant ()", + ] + + return "\n".join(declare_base) + + # TODO Camembert and XLMRoberta tokenizers has a different bos_token_id (0) from the default value (1) + # Now, we are hacking it. 
+ + def hack_bos_id(): + if tweak_bos_id: + return f''' + k_start = Constant () + input_ids_concat02 = Concat (k_start, token) + input_ids_bdim = Unsqueeze(input_ids_concat02, i64_0) + ''' + else: + return ''' + input_ids_bdim = Unsqueeze(token, i64_0) + ''' + + def build_unsqueeze(): + if len(input_shape_0) == 1: + return f""" + input_with_batch = Unsqueeze({self.input_names[0]}, i64_0) + """ + else: + return f""" + input_with_batch = Identity({self.input_names[0]}) + """ + + converter_graph = onnx.parser.parse_graph( + f"""\ + SentencePiecetokenizer ({build_input_declare()}) + => (int64[{output_shape_str}] {self.output_names[0]},int64[{output_shape_str}] {self.output_names[1]}) + {{ + {build_forward_declare()} + i64_neg1 = Constant () + i64_0 = Constant () + {build_unsqueeze()} + token,idx = com.microsoft.extensions.SentencepieceTokenizer ({build_call_para()}) + {hack_bos_id()} + {self.output_names[0]} = Cast (input_ids_bdim) + attention_mask_i32=Greater({self.output_names[0]}, i64_neg1) + {self.output_names[1]} = Cast (attention_mask_i32) + }} + """ + ) + + with open(self._tokenizer_param.vocab_or_file, "rb") as f: + content = f.read() + + token_model_attr = onnx.helper.make_attribute("model", content) + node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == "SentencepieceTokenizer") + converter_graph.node[node_idx].attribute.append(token_model_attr) + + return converter_graph + + +def _vocab_to_dict(vocab_or_file: Union[Dict[str, int], Path, str]): + if isinstance(vocab_or_file, (Path, str)): + # read from file + import json + with open(vocab_or_file, "r") as f: + vocab = json.load(f) + else: + vocab = vocab_or_file + + ordered_vocab = OrderedDict(sorted(vocab.items(), key=lambda item: int(item[1]))) + + vocab = "\n".join(ordered_vocab.keys()) + return dict(vocab_file=vocab) + + +class BertTokenizer(Step): + def __init__(self, tokenizer_param: TokenizerParam, name: Optional[str] = None): + """ + Brief: This step is used to convert the input text into the input_ids, attention_mask, token_type_ids. + It supports an input of a single string for classification models, or two strings for QA models. + Args: + tokenizer_param: some essential infos to build a tokenizer, + You can create a TokenizerParam like this: + tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, # vocab is dict or file_path + strip_accents = True or False (Optional), + do_lower_case = True or False (Optional) + ) + + name: Optional name of step. 
Defaults to 'BertTokenizer' + + """ + super().__init__(["input_text"], ["input_ids", "attention_mask", "token_type_ids"], name) + self._tokenizer_param = tokenizer_param + + def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int): + input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0) + + input_shape_0 = input_shape_str0.split(",") + prefix_ = f'step_{self.step_num}' + # only support bath size 1 until tokenizer op supports batch size > 1 + batch_dim = input_shape_0[0] if len(input_shape_0) > 1 else "1" + output_shape_str = f"{batch_dim}, _{prefix_}__num_ids" + assert input_type_str0 == "string" + + onnx_tokenizer_impl = "HfBertTokenizer" if self._tokenizer_param.is_sentence_pair else "BertTokenizer" + + def build_output_declare(): + output_base = [] + for out in self.output_names: + output_base.append(f"int64[{output_shape_str}] {out}") + + return ",".join(output_base) + + def get_tokenizer_ret(): + if onnx_tokenizer_impl == "HfBertTokenizer": + return ",".join(self.output_names) + # different output orders for BertTokenizer and HfBertTokenizer + return "ids,types,mask" + + def build_output_imp(): + if onnx_tokenizer_impl == "HfBertTokenizer": + return "" + + # BertTokenizer has different output dimensions + ret_vars = get_tokenizer_ret().split(",") + ret_vars[1], ret_vars[2] = ret_vars[2], ret_vars[1] + output_str = [] + + for idx, out in enumerate(self.output_names): + output_str.append(f"{out} = Unsqueeze({ret_vars[idx]}, i64_0)") + + return "\n".join(output_str) + + def build_input_declare(): + inputs = f"{input_type_str0}[{input_shape_str0}] {self.input_names[0]}" + return inputs + + def build_unsqueeze(): + if len(input_shape_0) == 1: + return f""" + input_with_batch = Unsqueeze({self.input_names[0]}, i64_0) + """ + else: + return f""" + input_with_batch = Identity({self.input_names[0]}) + """ + + converter_graph = onnx.parser.parse_graph( + f"""\ + {onnx_tokenizer_impl} ({build_input_declare()}) + => ({build_output_declare()}) + {{ + i64_0 = Constant () + {build_unsqueeze()} + {get_tokenizer_ret()} = com.microsoft.extensions.{onnx_tokenizer_impl} (input_with_batch) + {build_output_imp()} + }} + """ + ) + + bert_tokenizer_param = self._tokenizer_param + token_model_attr = [] + + attrs = _vocab_to_dict(bert_tokenizer_param.vocab_or_file) + attrs["strip_accents"] = bert_tokenizer_param.strip_accents + attrs["do_lower_case"] = bert_tokenizer_param.do_lower_case + + for attr in attrs: + token_model_attr.append(onnx.helper.make_attribute(attr, attrs[attr])) + + node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == onnx_tokenizer_impl) + converter_graph.node[node_idx].attribute.extend(token_model_attr) + + return converter_graph + + +class BertTokenizerQADecoder(Step): + def __init__(self, tokenizer_param: TokenizerParam, name: Optional[str] = None): + """ + Brief: + Decode the input_ids to text + Args: + tokenizer_param: some essential info to build a tokenizer. + you can create a TokenizerParam object like: + tokenizer_param = TokenizerParam(vocab=tokenizer.vocab, #vocab is dict or file_path) + name: Optional name of step. 
Defaults to 'BertTokenizerQADecoder' + """ + super().__init__( + ["start_logits", "end_logits", "input_ids"], ["text"], name) + self._tokenizer_param = tokenizer_param + + def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int): + def build_input_declare(): + inputs = [] + for idx, inp in enumerate(self.input_names): + input_type_str_x, input_shape_str_x = self._get_input_type_and_shape_strs(graph, idx) + inputs.append(f"{input_type_str_x}[{input_shape_str_x}] {inp}") + return ",".join(inputs) + + # A unique name for output shape + prefix_ = f'step_{self.step_num}' + output_shape_str = f"_{prefix_}_any_len" + converter_graph = onnx.parser.parse_graph( + f"""\ + tokenizer_decoder ({build_input_declare()}) + => (string[{output_shape_str}] {self.output_names[0]}) + {{ + i64_em = Constant () + i64_1 = Constant () + i64_0 = Constant () + i64_neg1 = Constant () + + s_position = ArgMax({self.input_names[0]}) + e_position = ArgMax({self.input_names[1]}) + ee_position = Add(e_position,i64_1) + u_i64_neg1 = Unsqueeze(i64_neg1, i64_0) + slice_ids= Slice({self.input_names[2]}, s_position, ee_position, i64_neg1) + {self.output_names[0]} = com.microsoft.extensions.BertTokenizerDecoder (slice_ids, i64_em) + }} + """ + ) + + attrs = _vocab_to_dict(self._tokenizer_param.vocab_or_file) + token_model_attr = [] + for attr in attrs: + token_model_attr.append(onnx.helper.make_attribute(attr, attrs[attr])) + + node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == "BertTokenizerDecoder") + converter_graph.node[node_idx].attribute.extend(token_model_attr) + + return converter_graph diff --git a/onnxruntime_extensions/tools/pre_post_processing/utils.py b/onnxruntime_extensions/tools/pre_post_processing/utils.py index 4c0c079a..2f84fcfc 100644 --- a/onnxruntime_extensions/tools/pre_post_processing/utils.py +++ b/onnxruntime_extensions/tools/pre_post_processing/utils.py @@ -21,7 +21,8 @@ def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]): Returns: An onnx.ValueInfoProto that can be used as a new model input. """ - tensor_type = onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape) + tensor_type = onnx.helper.make_tensor_type_proto( + elem_type=data_type, shape=shape) return onnx.helper.make_value_info(name, tensor_type) @@ -81,6 +82,32 @@ class IoMapEntry: consumer_idx: int = 0 +@dataclass +class IOEntryValuePreserver: + """ + used to allow an output value to have multiple consumers, + which is only possible when IoMapEntry is used to create those additional connections. + + Generally, a connection consumes an output and an input, then the output is removed from the graph. + This class enabled one-to-many connections by making the other consumers share the same output. + + How this class works: + 1. when the IoMapEntry is created, this class will be created simultaneously. + 2. It records the producer and consumer steps, and the output index of the producer step. + when producer step is running, this IOEntryValuePreserver will be activated and start to preserve the output. + 3. when graph merge happens, this class will check if the output is still in the graph, if not, + it will add the output + 4. when consumer step is running, this class will be deactivated and remove output from preserved_list. 
+ """ + + producer: Union["Step", str] = None + consumer: Union["Step", str] = None + # output index from the producer step + producer_idx: int = 0 + is_active: bool = False + output: str = None + + def sanitize_output_names(graph: onnx.GraphProto): """ Convert any usage of invalid characters like '/' and ';' in value names to '_' diff --git a/test/data/bert.vocab b/test/data/bert.vocab new file mode 100644 index 00000000..573fb460 --- /dev/null +++ b/test/data/bert.vocab @@ -0,0 +1 @@ +{"##l": 57, "got": 478, "tech": 1745, "veg": 1951, "##dy": 509, "crazy": 996, "bff": 1996, "said": 467, "dear": 739, "##h": 45, "care": 700, "win": 1891, "##ingfacewithopenmouth": 1967, "inf": 1488, "on": 139, "they": 470, "stay": 972, "##ying": 1318, "photo": 679, "rea": 200, "tra": 1202, "##winkingface": 1678, "grinningface": 1073, "grinningfacewith": 691, "##oz": 1667, "awes": 589, "app": 760, "##ute": 722, "##ingcatface": 485, "program": 1168, "##more": 1136, "att": 1199, "enough": 919, "reason": 1256, "##et": 150, "mine": 947, "##antic": 1623, "12": 1747, "sick": 1819, "catch": 1972, "shut": 849, "possible": 1686, "babe": 1563, "##gu": 1171, "thinking": 1109, "##ction": 1507, "its": 369, "anyone": 1121, "abt": 1217, "hop": 765, "time": 341, "ohk": 1682, "##light": 2000, "suggest": 1365, "exam": 762, "##pensiveface": 1914, "coll": 1005, "ob": 1143, "ins": 1120, "##ingface": 119, "##facewithsteamfromnose": 1757, "am": 135, "reading": 1754, "##ool": 342, "birth": 1295, "nons": 1501, "hur": 520, "lets": 1001, "##oringfood": 1302, "job": 757, "depress": 1054, "##rim": 1447, "grinningfacewithsmilingeyes": 1160, "pass": 1517, "happened": 772, "second": 1850, "2": 8, "google": 1260, "show": 672, "##blowingak": 836, "believe": 902, "yay": 1170, "because": 335, "##ya": 1913, "re": 175, "place": 793, "cha": 1472, "loudlycryingfaceloudlycryingfaceloudlycryingface": 1550, "hu": 786, "open": 1607, "look": 515, "##art": 388, "##joy": 148, "inde": 1525, "ext": 2005, "worried": 1968, "favou": 1339, "war": 1875, "taking": 1746, "vide": 951, "expl": 1016, "##mber": 960, "tht": 1669, "##faceblowingakiss": 1028, "ever": 332, "xd": 1032, "##open": 1600, "welc": 537, "a": 16, "##0": 72, "##k": 52, "##ea": 86, "yea": 253, "deser": 1957, "every": 418, "yr": 1766, "##en": 97, "##reak": 719, "sy": 1578, "ra": 1751, "buddy": 1792, "##ight": 244, "kiss": 891, "##x": 70, "##man": 1297, "##uck": 718, "gon": 779, "##bro": 1107, "everything": 746, "enjoying": 1249, "sim": 1333, "o": 30, "char": 1435, "lol": 409, "20": 1420, "eating": 1889, "##owing": 767, "yeah": 287, "ag": 425, "plan": 843, "ba": 311, "language": 998, "chicken": 1834, "yes": 167, "bl": 726, "definitely": 1378, "search": 1604, "wont": 1881, "'": 5, "##ting": 421, "english": 734, "ton": 1502, "##oose": 1537, "inter": 1172, "watched": 1530, "problems": 1771, "##cryingface": 254, "##lessface": 1362, "relationship": 1134, "##smilingfacewithhearteyes": 988, "##ames": 1301, "father": 1997, "easy": 1926, "##aste": 1383, "underst": 497, "##quint": 621, "##fer": 1157, "team": 1863, "already": 725, "matter": 1195, "playing": 1687, "really": 269, "##faceangry": 1728, "##ut": 134, "cool": 414, "mobile": 1434, "mist": 1516, "family": 1452, "##alo": 1899, "##ick": 706, "##ract": 1673, "##ew": 1211, "song": 635, "beautiful": 1019, "darling": 1322, "whatsapp": 1756, "wasn": 1506, "##ightlyfrowningface": 1707, "rain": 1911, "ye": 1151, "##gn": 936, "##con": 1328, "##ig": 377, "problem": 684, "##ly": 132, "##ud": 264, "aww": 1186, "opin": 1852, "##ache": 1706, "another": 978, "mov": 
411, "##lling": 1605, "shy": 1987, "##alk": 202, "##st": 109, "##isappoint": 453, "##oud": 270, "##ould": 239, "##fect": 1096, "##rollingeyes": 1853, "waste": 1820, "wanna": 518, "supp": 1158, "##ors": 1524, "##ese": 886, "##redface": 1950, "vo": 1395, "stud": 628, "books": 1851, "##ild": 1490, "com": 451, "##ps": 1066, "same": 472, "sister": 1622, "yu": 1715, "##ared": 1520, "##ap": 1095, "##hand": 977, "##haha": 1290, "should": 439, "heal": 1816, "##ries": 1943, "ho": 745, "vis": 1890, "different": 1442, "##age": 469, "those": 1206, "intell": 920, "heart": 599, "##grinningface": 1076, "##iness": 1183, "wan": 185, "##ber": 804, "##ast": 452, "heard": 1511, "funny": 309, "since": 1518, "facewith": 995, "del": 1528, "hear": 727, "wri": 1645, "##ag": 587, "##big": 810, "start": 664, "woman": 1773, "saw": 1494, "u": 36, "think": 286, "he": 314, "could": 668, "te": 617, "cl": 580, "##hands": 1625, "you": 80, "nice": 402, "##ell": 142, "opt": 1999, "##rong": 551, "coffee": 1499, "s": 34, "say": 262, "happ": 250, "after": 768, "this": 236, "co": 209, "stop": 591, "miss": 473, "##wn": 956, "fall": 1880, "intellig": 938, "dont": 328, "cou": 818, "##ers": 360, "bat": 1656, "##ish": 352, "thnx": 1925, "smilingface": 1948, "jo": 351, "grinningfacewithsweat": 1010, "facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 644, "cr": 1031, "others": 1838, "##au": 237, "important": 1760, "##ul": 437, "bored": 1141, "##mes": 1059, "##withtearsofjoy": 157, "##li": 1269, "##ash": 1270, "dec": 1809, "female": 1927, "##ited": 1192, "##v": 59, "##eak": 744, "phone": 763, "dist": 1104, "donapost": 1876, "bro": 456, "##ering": 1254, "##lc": 524, "##xt": 566, "mistake": 1909, "##ball": 1541, "word": 1234, "write": 1396, "##llywood": 1471, "pics": 1382, "though": 845, "##blem": 677, "age": 1224, "##grinningfacewithbigeyes": 1439, "##gether": 1508, "unt": 1938, "mil": 1797, "the": 107, "si": 1219, "##my": 1893, "##ark": 1373, "wait": 562, "yep": 1340, "eng": 1337, "wo": 1579, "##ven": 391, "work": 441, "sorry": 330, "chocolate": 1609, "sat": 1818, "##llyw": 1457, "too": 256, "##loudlycryingface": 356, "##ine": 285, "##ding": 641, "ab": 192, "pa": 1764, "remember": 1135, "nah": 1450, "had": 558, "##ress": 502, "##ate": 220, "why": 159, "##iff": 925, "##unamusedface": 1449, "##ari": 1204, "wife": 1908, "ty": 707, "whom": 1381, "hell": 501, "po": 573, "ev": 1731, "both": 1009, "idea": 1123, "##idd": 1570, "lo": 138, "##usedface": 834, "mach": 1817, "##you": 1445, "##lieved": 863, "listen": 808, "##r": 47, "prett": 733, "company": 1921, "hon": 1011, "##av": 968, "excited": 1642, "##ual": 1389, "##ween": 1697, "hours": 1774, "##ered": 1505, "##ither": 985, "truth": 1209, "##at": 82, "hate": 318, "en": 353, "loudlycryingface": 652, "ac": 893, "li": 127, "no": 92, "true": 881, "boyfriend": 882, "##ect": 799, "contact": 1593, "giving": 1708, "asking": 686, "##brokenheartbrokenheart": 1596, "##ie": 279, "##ed": 114, "bec": 312, "please": 339, "##im": 407, "ugly": 1367, "when": 367, "sad": 340, "##ww": 830, "##aught": 1681, "talk": 214, "deep": 1930, "##ons": 861, "missing": 1360, "gen": 1105, "sub": 1699, "##ingcatfacewithsmilingeyes": 1294, "##oot": 1329, "never": 435, "##ean": 252, "chill": 1426, "cho": 933, "gi": 1046, "intelligence": 1313, "going": 382, "##savoringfood": 1306, "el": 729, "into": 1529, "world": 903, "home": 598, "##ss": 337, "har": 775, "artificial": 1551, "##ral": 1821, "##c": 61, "##smilingeyes": 327, "##ation": 399, "alright": 1398, "##ways": 458, "##inkingfacewithtongue": 1097, "reme": 1067, "##or": 
103, "okay": 368, "trust": 1292, "##ssible": 1476, "##rimac": 1610, "##nect": 1717, "sweet": 647, "z": 41, "lat": 912, "teach": 997, "through": 1854, "inv": 1558, "annoy": 667, "##facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 1326, "gl": 773, "short": 1966, "##unch": 1228, "thought": 651, "fan": 1612, "cu": 1308, "##m": 50, "##ror": 1892, "##rent": 1984, "help": 525, "##ath": 636, "choc": 1349, "##gram": 1140, "dude": 922, "##set": 1108, "##confoundedface": 1976, "##end": 168, "interested": 1176, "##umb": 361, "share": 1092, "cric": 1749, "pun": 1936, "##ger": 1343, "clo": 1394, "sex": 790, "well": 322, "yup": 809, "c": 18, "fo": 1099, "crush": 1401, "##est": 240, "weekend": 1482, "future": 1421, "blue": 1883, "drink": 1314, "alre": 723, "##act": 703, "##ense": 764, "##ital": 2003, "##um": 251, "##oint": 383, "bet": 481, "her": 569, "him": 821, "##lt": 1287, "hmm": 499, "sarc": 1519, "um": 1752, "agree": 1375, "it": 108, "whatever": 1126, "up": 274, "##ult": 1347, "ill": 1486, "##oooo": 1127, "change": 1040, "##f": 60, "run": 1961, "##ght": 204, "for": 162, "part": 748, "##facewithtearsofjoyfacewithtearsofjoy": 357, "##ur": 273, "##other": 952, "##cept": 1184, "##ingfacewithsunglasses": 1655, "des": 1039, "count": 1154, "even": 482, "feeling": 506, "##int": 275, "langu": 874, "cute": 962, "obvious": 1440, "jokes": 1266, "##ict": 1549, "hour": 1317, "give": 423, "##rect": 1053, "pri": 1577, "lear": 954, "real": 616, "fav": 708, "stand": 1900, "rem": 1671, "##ling": 1026, "iam": 841, "btw": 1626, "poutingface": 1100, "eat": 833, "bot": 873, "spec": 1083, "##ex": 1569, "let": 420, "##not": 1522, "##io": 1668, "mu": 354, "##grinningsquintingface": 1085, "cuz": 1759, "##redheart": 1320, "##per": 561, "##ning": 690, "##rom": 325, "does": 552, "##ac": 701, "##le": 189, "##withtong": 713, "guy": 755, "dis": 942, "##facewith": 980, "having": 1004, "weird": 1458, "##ens": 930, "says": 1683, "stup": 374, "##rown": 872, "bf": 1350, "##fast": 1982, "##ou": 78, "yourself": 792, "between": 1724, "##steamfrom": 1013, "##ity": 666, "ignore": 1572, "link": 1370, "rom": 1410, "##q": 66, "##press": 634, "last": 867, "confused": 1886, "soo": 1225, "foot": 1763, "need": 385, "##ack": 1213, "sw": 532, "##ions": 1232, "irritating": 798, "n": 29, "##id": 140, "##pose": 1321, "favorite": 1251, "forget": 1415, "either": 1552, "##on": 88, "watching": 1161, "else": 782, "non": 1733, "exper": 1700, "ly": 1860, "##ci": 1220, "aren": 1207, "check": 1162, "lik": 1231, "appre": 1905, "ma": 1861, "college": 1223, "##gs": 1334, "10": 999, "loud": 1553, "game": 806, "shall": 1407, "picture": 717, "wor": 268, "##ating": 584, "##tiredface": 1954, "online": 1310, "import": 1629, "ce": 1835, "##ete": 1288, "da": 1351, "than": 211, "party": 1385, "die": 1794, "##ah": 1144, "state": 1741, "##ia": 1753, "breakfast": 1991, "##w": 63, "might": 1417, "disappointed": 1587, "fu": 188, "profile": 1994, "##fl": 1633, "##sof": 151, "news": 1637, "busy": 819, "seem": 970, "gift": 1777, "swee": 605, "##b": 62, "##an": 91, "##zz": 1392, "emoj": 2015, "smilingcatfacewithhearteyes": 1684, "typ": 1481, "dr": 1043, "x": 39, "tw": 907, "fact": 1786, "studying": 1543, "##ke": 122, "serious": 665, "##as": 181, "photos": 1869, "##8": 68, "##by": 496, "whats": 613, "##oto": 673, "day": 365, "##steamfromnose": 1035, "buy": 1542, "first": 614, "ms": 1371, "find": 660, "sur": 1590, "d": 19, "proba": 1436, "haven": 1255, "##ition": 1824, "##out": 176, "##ouch": 1737, "lot": 649, "##lievedface": 877, "##ated": 753, "doesn": 1002, "school": 1179, 
"##smil": 247, "what": 102, "movies": 1062, "##ire": 1353, "cont": 824, "yo": 1567, "brokenheart": 1947, "spe": 967, "##grinningfacewithsmilingeyes": 1248, "##bb": 1696, "get": 242, "##ty": 1456, "won": 610, "broke": 832, "kissingcatface": 1594, "where": 302, "resp": 1408, "supposed": 1833, "##ey": 191, "baby": 595, "games": 1732, "ap": 1977, "tiredface": 1872, "##ats": 1418, "uh": 1203, "##ace": 93, "smart": 1149, "again": 510, "##ally": 417, "kidding": 989, "##steamfromn": 1014, "sleeping": 1538, "his": 1815, "##friend": 860, "red": 955, "##ose": 663, "myself": 969, "ask": 308, "##key": 1581, "wrong": 583, "##ent": 291, "##our": 492, "##loor": 1744, "wat": 398, "pre": 495, "##oudly": 294, "pic": 370, "prof": 1428, "##ingcatfacewith": 858, "##if": 430, "##lock": 964, "friendship": 1601, "till": 1677, "##kk": 1078, "##ply": 559, "##conf": 1675, "most": 761, "address": 1726, "being": 576, "country": 1441, "##way": 355, "##ough": 698, "belong": 1887, "oh": 249, "##able": 1052, "con": 386, "jud": 1978, "ne": 203, "##ange": 898, "##ver": 182, "wtf": 1812, "low": 1621, "##ri": 145, "##ip": 608, "##cing": 1822, "proper": 1969, "##ming": 1716, "better": 553, "##u": 43, "intern": 1709, "##n": 58, "##no": 1191, "god": 1216, "not": 120, "sit": 1252, "an": 106, "##is": 123, "only": 415, "back": 529, "gonna": 795, "##relievedface": 1355, "##sw": 392, "artific": 1532, "ter": 1713, "tel": 1990, "##ial": 712, "sec": 892, "##ouse": 1205, "call": 397, "##ingeyes": 319, "##reen": 1780, "##isapp": 448, "##bu": 1802, "fore": 1369, "##gle": 1113, "##ric": 1178, "play": 579, "human": 692, "##o": 49, "##ich": 378, "##ign": 1374, "art": 1103, "imag": 1807, "sl": 359, "wouldn": 1989, "send": 333, "expect": 1672, "##ue": 412, "##apo": 981, "worst": 1649, "##de": 876, "must": 1106, "##beamingfacewithsmilingeyesbeamingfacewithsmilingeyes": 1451, "break": 778, "there": 381, "thing": 578, "bad": 315, "useless": 1906, "grinningcatfacewithsmilingeyes": 1510, "4": 10, "choose": 1845, "##6": 75, "##alent": 1680, "##vel": 927, "everyone": 908, "##ther": 384, "forever": 1544, "##za": 1423, "b": 17, "very": 280, "totally": 1711, "f": 21, "al": 201, "mean": 278, "cur": 1659, "don": 128, "e": 20, "sh": 217, "smil": 462, "##ever": 916, "math": 1554, "special": 1377, "propose": 1988, "tell": 221, "##any": 1598, "##estion": 500, "##s": 51, "##ant": 568, "##grinn": 447, "av": 1761, "was": 229, "type": 1093, "went": 1714, "pizza": 1514, "without": 1198, "inst": 1559, "mother": 1917, "are": 112, "##ust": 178, "##ightly": 826, "someone": 588, "pain": 948, "married": 1348, "##les": 1503, "##earsof": 155, "##ving": 531, "perfect": 1259, "faceblowingakiss": 1783, "##own": 720, "joke": 650, "conf": 899, "water": 1650, "##ious": 498, "##az": 731, "##pression": 1272, "##irk": 1465, "looks": 1478, "money": 921, "moment": 1826, "coup": 1919, "##cl": 1983, "##oudlycryingface": 297, "spo": 2008, "##ually": 597, "exc": 1116, "##ti": 848, "##ame": 241, "##ensiveface": 1431, "##app": 394, "ex": 306, "will": 226, "nor": 1424, "##iss": 321, "pur": 1239, "##di": 464, "w": 38, "##ng": 1091, "dinner": 1036, "##j": 64, "##tim": 1101, "answ": 548, "##vious": 1264, "is": 118, "##ied": 1402, "##blowing": 835, "child": 1827, "subject": 1873, "great": 511, "hehe": 1061, "anything": 556, "our": 883, "software": 1574, "##ic": 180, "hum": 618, "enjoy": 742, "wom": 1182, "##disappointedface": 702, "listening": 1789, "cause": 950, "living": 1942, "free": 959, "colle": 1196, "lunch": 1500, "##eal": 923, "relation": 1064, "##bly": 1390, "##hhhh": 1755, "##rible": 
1944, "their": 1114, "wanted": 1325, "j": 25, "turn": 1758, "##bo": 750, "by": 329, "live": 802, "##sy": 1841, "mind": 685, "boys": 1904, "neither": 1561, "##3": 77, "##son": 1047, "sleepy": 1743, "who": 259, "guess": 676, "##pressionlessface": 1366, "##right": 1131, "star": 1769, "far": 1342, "##loudlycryingfaceloudlycryingface": 493, "##faceblowingakissfaceblowingakiss": 1573, "self": 1029, "m": 28, "before": 1079, "##pect": 1147, "##ain": 307, "##arc": 1448, "always": 474, "upset": 1185, "##1": 65, "##iddle": 1776, "##rinn": 255, "[PAD]": 0, "fuck": 379, "##sweat": 697, "forg": 842, "tot": 1536, "she": 540, "working": 1081, "##eeee": 1670, "##ub": 1282, "have": 171, "##lie": 565, "3": 9, "##ile": 756, "col": 1526, "read": 896, "##bl": 693, "##gg": 740, "student": 1651, "around": 1286, "##eam": 405, "ga": 1836, "##smilingfacewithsmilingeyes": 1427, "##ower": 1915, "tur": 1261, "become": 1617, "gu": 455, "right": 428, "idk": 1882, "voice": 1533, "marry": 1406, "[SEP]": 3, "thr": 1504, "##ak": 334, "much": 362, "##and": 303, "little": 889, "ahead": 1956, "##ill": 184, "my": 126, "eas": 1644, "slightlysmilingface": 1477, "##blowingakiss": 840, "own": 1197, "confus": 1409, "hand": 1130, "##am": 230, "##face": 98, "topic": 1065, "##bot": 828, "##ourse": 800, "##ach": 749, "##sav": 1298, "##du": 1422, "##osed": 1811, "##but": 1230, "do": 121, "least": 1592, "##iew": 1918, "cryingface": 853, "##outh": 979, "story": 1115, "##ingsquintingface": 728, "##ali": 1964, "##ere": 177, "per": 817, "[CLS]": 2, "lost": 504, "film": 1796, "##ow": 101, "##ret": 1084, "##ple": 1023, "mon": 632, "##ood": 164, "##isappointedface": 489, "##food": 1283, "guys": 1414, "brain": 1585, "things": 704, "comm": 1006, "understand": 521, "##skintone": 1856, "##quintingfacewithtongue": 1725, "depend": 1429, "tr": 440, "more": 422, "##ee": 136, "ear": 911, "put": 1765, "##apost": 1674, "under": 475, "differe": 1189, "##orm": 1941, "##eart": 271, "##ause": 305, "new": 593, "irrit": 640, "di": 1118, "redheart": 1586, "interest": 710, "awesome": 596, "secret": 1246, "been": 602, "0": 6, "##mouth": 1262, "forgot": 1531, "wearyface": 1632, "many": 695, "##inkingface": 694, "food": 993, "hair": 1487, "dog": 1739, "alone": 541, "op": 1090, "best": 479, "aw": 416, "fr": 689, "##ba": 714, "th": 83, "earth": 1871, "##ound": 582, "pr": 1070, "answer": 609, "##ore": 406, "##ry": 125, "tri": 1693, "##na": 371, "##outingface": 577, "##d": 56, "sound": 1177, "days": 1017, "nobody": 1242, "##es": 95, "pe": 494, "fake": 1112, "##eep": 350, "##facean": 1277, "started": 1430, "br": 934, "defin": 1150, "see": 300, "##der": 984, "##ure": 431, "##use": 1735, "off": 454, "##gin": 1840, "pick": 1493, "your": 144, "##ol": 878, "rest": 1603, "liked": 1993, "##un": 444, "honey": 1992, "##nt": 338, "##ise": 901, "##poutingfacepoutingface": 1072, "[MASK]": 4, "ser": 555, "calling": 1723, "then": 258, "int": 450, "k": 26, "##cent": 1464, "compliment": 1955, "ok": 154, "ar": 592, "100": 2013, "##ream": 785, "##oring": 620, "##ention": 1635, "wh": 85, "##qu": 517, "##su": 1521, "dreams": 1727, "##cat": 433, "hot": 1201, "gott": 1804, "lea": 514, "sn": 1734, "friends": 508, "hurts": 1235, "##one": 223, "choice": 1619, "haa": 2001, "engl": 715, "msg": 1608, "smile": 1584, "##openmouth": 1729, "correct": 1341, "##ingfacewith": 222, "cry": 631, "##ful": 926, "##oy": 133, "##el": 1463, "kill": 1692, "i": 24, "##asses": 1602, "##lf": 316, "ph": 434, "plz": 857, "here": 436, "fir": 603, "super": 944, "haha": 263, "##down": 1842, "used": 1460, "pretty": 743, 
"sometimes": 1167, "##ular": 1868, "add": 992, "hahaha": 659, "##ook": 403, "##pt": 716, "course": 885, "frnd": 1466, "tonight": 1611, "det": 1858, "##earyface": 1139, "knew": 1319, "pls": 961, "bud": 1575, "together": 1557, "cant": 888, "impress": 1701, "gotta": 1959, "r": 33, "##row": 491, "def": 1030, "sun": 1866, "##ship": 1069, "feelings": 1257, "be": 124, "se": 373, "actually": 670, "##steam": 994, "na": 619, "##ch": 143, "exactly": 1180, "angryface": 1703, "##ance": 781, "luck": 913, "##most": 1801, "match": 1884, "##rit": 465, "##lfriend": 574, "years": 1236, "disappointedface": 788, "thumbs": 1124, "huh": 1293, "month": 1479, "wish": 711, "##withtearsof": 156, "micro": 1470, "##ake": 282, "list": 732, "lame": 1296, "all": 299, "comp": 852, "##disappointedfacedisappointedface": 1276, "something": 393, "next": 844, "h": 23, "facewithsteamfromnose": 1467, "##with": 117, "fa": 372, "nope": 678, "1": 7, "probably": 1513, "litt": 887, "almost": 1878, "thou": 484, "suck": 1806, "year": 675, "connect": 1808, "angry": 445, "hard": 906, "microsoft": 1515, "creat": 990, "also": 461, "##wearyface": 1864, "##gr": 931, "hm": 1660, "##ion": 212, "over": 730, "pleasure": 1618, "##ha": 187, "tomor": 811, "##ightlysmilingface": 1049, "wa": 507, "happen": 533, "exams": 1303, "##are": 639, "##iment": 1867, "bollywood": 1848, "welcome": 542, "##inn": 238, "##ear": 115, "imp": 1208, "depends": 1565, "##its": 1491, "##ds": 784, "clearly": 1653, "##ce": 87, "scared": 1952, "blocked": 2014, "getting": 770, "##il": 160, "broken": 1393, "sry": 1962, "bo": 557, "##old": 1800, "suff": 1583, "few": 1384, "dare": 1813, "gr": 966, "try": 477, "boy": 572, "at": 295, "travel": 1791, "##ing": 81, "##sc": 1779, "##ank": 1843, "du": 1630, "cook": 1748, "##glasses": 1636, "just": 194, "man": 426, "un": 600, "##ture": 699, "##ct": 231, "mood": 854, "meant": 1346, "hurt": 560, "##ond": 1125, "umm": 1613, "##ot": 257, "##poutingface": 736, "##ooo": 1425, "##us": 288, "##fore": 1060, "still": 594, "ir": 604, "anyway": 1345, "yester": 1311, "indian": 1399, "##5": 76, "##ister": 1539, "seen": 1212, "once": 1489, "mor": 946, "sch": 1071, "##fac": 1229, "norm": 1975, "##hh": 468, "we": 219, "##ken": 615, "mak": 688, "book": 1045, "words": 1285, "##ost": 1372, "about": 208, "p": 31, "come": 419, "them": 625, "##ird": 1274, "tried": 1798, "sol": 1963, "ice": 1998, "##to": 1778, "ha": 105, "##bigeyes": 815, "friend": 293, "body": 1830, "##po": 550, "##angu": 859, "chocol": 1416, "watch": 516, "badly": 1879, "ch": 225, "##ouble": 1546, "awww": 1702, "poor": 1258, "##ible": 1263, "##al": 141, "talking": 390, "##ff": 538, "thank": 234, "hurting": 1810, "su": 233, "with": 186, "cat": 803, "point": 1169, "str": 1548, "sleep": 463, "##yy": 687, "office": 1438, "##thers": 1647, "hy": 1814, "##ang": 534, "##pend": 1323, "##zy": 943, "dar": 957, "##ilm": 1357, "hope": 812, "clear": 1174, "move": 1344, "and": 152, "hor": 1387, "ti": 284, "res": 869, "##ward": 1615, "suffering": 1793, "sill": 1839, "bed": 909, "chic": 1784, "dark": 2011, "asked": 831, "romantic": 1691, "kind": 630, "##lieve": 586, "soon": 949, "##ject": 1164, "numb": 567, "##ble": 404, "thats": 645, "favourite": 1386, "ohhh": 1080, "im": 324, "emot": 1689, "know": 153, "##facesavoringfood": 1896, "whole": 1527, "india": 814, "##umbs": 879, "making": 1133, "lie": 1299, "class": 1361, "yaa": 1591, "dumb": 847, "sal": 1912, "##bile": 1391, "hw": 1932, "amaz": 870, "kn": 146, "##gh": 166, "##ite": 543, "##winkingfacewithtongue": 1718, "ah": 1021, "##hip": 1589, "tat": 1979, 
"idiot": 1094, "##ingfacewithhearteyes": 637, "##perseveringface": 1970, "fant": 1931, "looking": 1063, "##itely": 1338, "##pp": 174, "trying": 905, "follow": 1627, "##ad": 158, "which": 400, "##be": 607, "##urb": 1275, "##llow": 1356, "shit": 1358, "##ics": 1721, "lovely": 1901, "sense": 1012, "##ice": 304, "text": 986, "##i": 55, "##rite": 880, "##where": 1363, "oo": 1455, "##ld": 190, "message": 971, "##iz": 794, "##from": 1003, "##guish": 1628, "car": 1181, "devel": 1722, "##ix": 1736, "now": 218, "did": 243, "gir": 331, "dance": 1267, "songs": 1480, "id": 721, "##mor": 738, "fam": 1142, "person": 575, "##irkingface": 1534, "##ph": 982, "experience": 1888, "left": 1111, "qu": 395, "##os": 871, "fast": 1068, "##augh": 638, "##ar": 170, "sugg": 1145, "video": 1034, "goes": 2002, "out": 432, "coff": 1492, "intelligent": 1652, "question": 530, "##wh": 895, "##butrelievedface": 1924, "yess": 1803, "fucking": 1153, "facewithtearsofjoyfacewithtearsofjoy": 654, "ru": 1665, "joking": 1496, "##sed": 1894, "relievedface": 1335, "favor": 1188, "##ive": 317, "##fe": 344, "leaving": 1829, "robot": 928, "ve": 627, "lone": 780, "waiting": 965, "catfacewithtearsofjoy": 1495, "kinda": 1540, "met": 1599, "prom": 1664, "may": 554, "sexy": 1953, "sor": 323, "bye": 547, "beamingfacewithsmilingeyes": 752, "any": 246, "##ail": 935, "room": 1388, "interesting": 1148, "found": 1795, "##ser": 1940, "girlfriend": 581, "week": 850, "dp": 1278, "dat": 1762, "bel": 1582, "##ong": 260, "thanks": 358, "##gry": 387, "tea": 855, "gud": 1576, "laugh": 682, "less": 1663, "doing": 490, "trouble": 1847, "yaar": 1194, "##ll": 96, "##bs": 1556, "##pe": 366, "##loudlycryingfaceloudlycryingfaceloudlycryingfaceloudlycryingface": 1221, "unamusedface": 1405, "absolute": 1857, "that": 131, "us": 528, "##withtongue": 759, "##sk": 1568, "##ro": 165, "disturb": 1364, "loudlycryingfaceloudlycryingface": 1704, "tru": 1075, "pers": 564, "down": 894, "##lo": 480, "##riend": 228, "fine": 446, "gf": 807, "fun": 245, "g": 22, "meet": 571, "would": 408, "take": 611, "t": 35, "tv": 1444, "##ried": 789, "##ink": 227, "mad": 570, "fl": 958, "##z": 67, "##ush": 865, "##less": 655, "##sol": 1782, "so": 111, "##9": 74, "like": 173, "girl": 443, "##ional": 1902, "##vice": 1555, "##ome": 281, "has": 754, "ann": 662, "done": 777, "chat": 487, "v": 37, "fem": 1705, "full": 1033, "nothing": 345, "##up": 266, "8": 14, "beaut": 914, "told": 822, "yet": 890, "mic": 1454, "wear": 1614, "single": 1087, "bea": 751, "##ort": 683, "walk": 1694, "##hhh": 1949, "meaning": 1244, "##aa": 771, "refer": 1916, "ca": 310, "bus": 671, "##the": 1895, "make": 438, "pl": 196, "bir": 1117, "rude": 523, "able": 1616, "##severingface": 1712, "yesterday": 1315, "other": 646, "hey": 546, "bc": 1307, "q": 32, "ign": 1000, "##he": 787, "##oom": 1024, "sent": 875, "conversation": 1247, "plea": 326, "usually": 2009, "l": 27, "##iv": 1767, "city": 1368, "dri": 1960, "stupid": 375, "girls": 1227, "side": 1862, "9": 15, "##red": 539, "mess": 661, "##ment": 851, "##g": 44, "repl": 1560, "##ist": 483, "##ous": 389, "lu": 1662, "##outingcatface": 1844, "##rowningface": 1138, "leave": 674, "top": 900, "ur": 298, "5": 11, "birthday": 1497, "liar": 1898, "om": 1934, "laughing": 1175, "##sp": 1981, "milk": 1995, "##ir": 198, "as": 210, "##ter": 283, "mo": 207, "bi": 1237, "##re": 84, "bitch": 1018, "people": 522, "##ount": 1679, "##cryingfacecryingface": 1233, "cricket": 1855, "seriously": 1042, "goo": 1056, "idi": 1077, "study": 839, "convers": 1226, "##ay": 113, "good": 183, "##ages": 
1245, "winkingface": 1799, "piss": 1935, "##phone": 1974, "trash": 2007, "yours": 1166, "computer": 1469, "speak": 797, "headache": 1790, "worry": 1404, "half": 1580, "such": 924, "depressed": 1305, "##4": 73, "##th": 89, "##ift": 1271, "pictures": 1638, "##ings": 401, "thn": 1354, "head": 963, "##slightlysmilingface": 1923, "wht": 953, "hi": 747, "these": 1027, "loves": 1300, "ready": 1475, "act": 519, "fool": 1443, "##smilingface": 918, "engine": 1512, "love": 197, "sending": 1828, "ohh": 696, "health": 1910, "##line": 1163, "che": 829, "6": 12, "coz": 1128, "stuff": 1243, "min": 1218, "##ard": 941, "##ide": 545, "##ny": 296, "ll": 429, "dead": 1453, "##earch": 1547, "ro": 563, "smilingfacewithsmilingeyes": 1187, "tomorrow": 823, "##sun": 1597, "indeed": 1643, "fin": 1089, "##ck": 199, "if": 313, "cra": 791, "fee": 289, "cor": 1238, "##ingcatfacewithhearteyes": 1330, "##t": 48, "##ather": 1082, "##eamingfacewithsmilingeyes": 505, "number": 606, "appreci": 1929, "##ustr": 1986, "glad": 827, "long": 776, "##sure": 1462, "late": 1190, "##ved": 1897, "##a": 42, "of": 149, "##body": 846, "while": 1523, "##ass": 624, "facewithtearsofjoy": 265, "##sever": 1698, "explain": 1098, "##ically": 1805, "##times": 1155, "indi": 642, "##ory": 825, "##sl": 1214, "two": 1200, "seems": 1483, "movie": 486, "##p": 54, "##day": 349, "##mm": 292, "##ai": 1939, "lonely": 805, "##uture": 1411, "grinningcatface": 1412, "sm": 737, "ey": 1352, "compl": 904, "in": 129, "big": 1050, "##er": 90, "crying": 1086, "##uu": 975, "music": 1055, "high": 1639, "##em": 1241, "failed": 1885, "##y": 46, "can": 161, "away": 976, "##ryingface": 235, "##grinningfacewithsweat": 1304, "st": 163, "##ves": 864, "sp": 471, "facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 1688, "##hearteyes": 549, "##ra": 376, "##ree": 585, "means": 656, "le": 709, "##frowningface": 1432, "except": 1945, "disappoint": 1379, "##ess": 276, "##med": 1606, "##yyyy": 1624, "##time": 1312, "reply": 612, "happy": 336, "smilingfacewithhearteyes": 932, "##ys": 1980, "##2": 69, "honest": 1484, "sing": 680, "tired": 1129, "##inner": 1007, "fell": 1588, "sounds": 1710, "way": 544, "##ence": 643, "##rr": 1380, "##ize": 1640, "personal": 1770, "##ert": 1781, "men": 1837, "exact": 1110, "##ery": 267, "sea": 1937, "life": 424, "##ince": 1215, "comput": 1376, "but": 193, "me": 99, "nat": 1631, "hello": 937, "##so": 137, "nonsense": 1545, "##ies": 536, "were": 601, "##tt": 442, "later": 1038, "##ushedface": 1922, "fe": 735, "cannot": 1719, "mar": 626, "pro": 347, "absol": 1825, "##erest": 705, "missed": 1122, "gone": 1676, "y": 40, "damn": 974, "boring": 862, "##ve": 104, "de": 396, "telling": 1336, "obviously": 1654, "grinningfacewithbigeyes": 1088, "using": 1685, "reg": 1877, "##vers": 983, "soft": 1446, "##ored": 973, "prefer": 1787, "questions": 1102, "hav": 1865, "sha": 648, "acc": 1400, "##mil": 205, "mr": 1933, "complete": 1870, "women": 1958, "learn": 1250, "##gl": 590, "maybe": 866, "hang": 1661, "##ind": 301, "bit": 658, "##ary": 1058, "##nd": 272, "##e": 53, "microsof": 1498, "##cess": 1634, "rec": 1720, "kid": 783, "jok": 838, "date": 1279, "ass": 1284, "eyes": 1690, "##nce": 1695, "called": 1193, "##00": 1646, "##ensive": 1331, "happens": 2010, "##eat": 527, "mom": 1403, "mat": 856, "##roll": 1289, "##kenheart": 945, "##smilingfacewithhearteyessmilingfacewithhearteyes": 2012, "##pl": 449, "how": 169, "##een": 1268, "##7": 71, "##catface": 460, "wonder": 1413, "##ical": 1173, "lucky": 1595, "business": 1832, "times": 1146, 
"##edface": 348, "name": 459, "hel": 512, "wt": 915, "saying": 741, "##ith": 100, "isn": 1265, "annoying": 758, "##heart": 346, "##ingfacewithsun": 1648, "[UNK]": 1, "##reat": 413, "some": 216, "scary": 1971, "##skint": 1775, "cheat": 1015, "dream": 1022, "##oo": 116, "##it": 179, "grinn": 363, "coming": 1041, "go": 172, "didn": 526, "##eek": 816, "brother": 1571, "house": 1309, "##wor": 1823, "wi": 1473, "silly": 1874, "sk": 1273, "7": 13, "##ude": 410, "##all": 1281, "post": 1419, "##oundedface": 1788, "emo": 1849, "##sm": 1535, "bt": 1044, "plans": 1920, "dam": 897, "works": 1903, "feel": 290, "##thing": 195, "##ific": 1074, "hmmm": 1222, "thumbsup": 1485, "came": 1658, "amazing": 929, "##se": 130, "ug": 1327, "cold": 1562, "##led": 1768, "block": 1051, "##ft": 769, "cheated": 1468, "pu": 1750, "##ware": 1433, "end": 1037, "##ations": 1846, "youre": 1253, "##nder": 466, "mus": 987, "okk": 1965, "##lly": 206, "##in": 79, "wow": 503, "##beamingfacewithsmilingeyes": 766, "anymore": 1159, "##ale": 1137, "##rimacingface": 1620, "##self": 535, "today": 488, "##ge": 277, "made": 724, "rel": 796, "loved": 1291, "grinningsquintingface": 1132, "one": 248, "from": 364, "##eary": 917, "##thumbs": 1738, "each": 1859, "develop": 1730, "or": 320, "beach": 1772, "af": 657, "##grinningsquintingfacegrinningsquintingface": 1907, "chatting": 991, "ref": 1985, "absolutely": 1928, "em": 820, "##log": 1459, "want": 215, "didnt": 1359, "dnt": 1316, "to": 94, "bu": 837, "##brokenheart": 1210, "expressionlessface": 1946, "##ingfacewithsmilingeyes": 343, "##grinningfacewith": 884, "ai": 774, "old": 868, "##ket": 940, "##ead": 622, "##ab": 1057, "fail": 1152, "##op": 261, "##pped": 1742, "winkingfacewithtongue": 1666, "##eyes": 232, "ya": 427, "night": 633, "morning": 1020, "##me": 110, "##sh": 681, "quite": 1564, "##rowningfacewithopenmouth": 1973, "##quintingface": 623, "knw": 1324, "sc": 653, "##facewithtearsofjoy": 213, "##irt": 2004, "form": 1740, "nt": 1332, "use": 801, "diff": 939, "tak": 1240, "nd": 1280, "##om": 476, "makes": 1025, "##uff": 1119, "##orn": 1165, "join": 1509, "gener": 1641, "exist": 2006, "##ness": 1474, "peop": 513, "sure": 457, "##witht": 147, "##arent": 1831, "sa": 380, "keep": 813, "##ily": 1048, "piz": 1461, "face": 224, "##ience": 1156, "##ump": 1785, "##amusedface": 910, "test": 1566, "##iful": 1008, "##ady": 669, "##ates": 1397, "bas": 1657, "##astic": 1437, "ad": 629} \ No newline at end of file diff --git a/test/data/bert_qa_decoder_base.onnx b/test/data/bert_qa_decoder_base.onnx new file mode 100644 index 00000000..963e678d Binary files /dev/null and b/test/data/bert_qa_decoder_base.onnx differ diff --git a/test/data/hfbert.vocab b/test/data/hfbert.vocab new file mode 100644 index 00000000..4b880397 --- /dev/null +++ b/test/data/hfbert.vocab @@ -0,0 +1 @@ +{"##gle": 1113, "##gram": 1140, "##1": 65, "##ense": 764, "##nder": 466, "[CLS]": 2, "##augh": 638, "ah": 1021, "ever": 332, "best": 479, "##ish": 352, "10": 999, "ra": 1751, "feeling": 506, "colle": 1196, "men": 1837, "makes": 1025, "z": 41, "day": 365, "office": 1438, "##witht": 147, "google": 1260, "aw": 416, "##u": 43, "mak": 688, "teach": 997, "ass": 1284, "travel": 1791, "cook": 1748, "8": 14, "sorry": 330, "q": 32, "act": 519, "##one": 223, "mother": 1917, "exist": 2006, "##nd": 272, "face": 224, "gone": 1676, "hours": 1774, "##heart": 346, "##ce": 87, "##ong": 260, "something": 393, "check": 1162, "##to": 1778, "tel": 1990, "facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 644, "fu": 188, "doing": 490, "##ook": 
403, "##redface": 1950, "lone": 780, "sk": 1273, "##oundedface": 1788, "pun": 1936, "make": 438, "pl": 196, "brain": 1585, "##it": 179, "listening": 1789, "##ur": 273, "##sc": 1779, "##facewithtearsofjoyfacewithtearsofjoy": 357, "dead": 1453, "yep": 1340, "dinner": 1036, "honest": 1484, "absolute": 1857, "coff": 1492, "cha": 1472, "bud": 1575, "wanna": 518, "loudlycryingfaceloudlycryingfaceloudlycryingface": 1550, "##up": 266, "ba": 311, "##ght": 204, "great": 511, "y": 40, "##pressionlessface": 1366, "##ried": 789, "have": 171, "food": 993, "po": 573, "mind": 685, "agree": 1375, "another": 978, "nonsense": 1545, "##down": 1842, "##row": 491, "##b": 62, "##lly": 206, "##and": 303, "pro": 347, "##grinningsquintingface": 1085, "##beamingfacewithsmilingeyes": 766, "defin": 1150, "##gu": 1171, "##disappointedfacedisappointedface": 1276, "nat": 1631, "sea": 1937, "##ast": 452, "voice": 1533, "##ons": 861, "ap": 1977, "beamingfacewithsmilingeyes": 752, "pictures": 1638, "video": 1034, "house": 1309, "stup": 374, "the": 107, "like": 173, "##sever": 1698, "##aste": 1383, "##le": 189, "##ang": 534, "##ild": 1490, "##ct": 231, "because": 335, "##pp": 174, "we": 219, "bec": 312, "block": 1051, "tv": 1444, "mach": 1817, "##ible": 1263, "run": 1961, "im": 324, "##io": 1668, "fem": 1705, "sex": 790, "##inn": 238, "must": 1106, "thought": 651, "possible": 1686, "##au": 237, "h": 23, "##red": 539, "dumb": 847, "chatting": 991, "girls": 1227, "##ood": 164, "ph": 434, "pls": 961, "##pt": 716, "##et": 150, "##out": 176, "test": 1566, "##ah": 1144, "##ical": 1173, "##no": 1191, "uh": 1203, "##ouch": 1737, "star": 1769, "except": 1945, "##t": 48, "##ds": 784, "##light": 2000, "litt": 887, "min": 1218, "##relievedface": 1355, "story": 1115, "##o": 49, "knw": 1324, "eyes": 1690, "fe": 735, "veg": 1951, "in": 129, "dare": 1813, "##ered": 1505, "fact": 1786, "whom": 1381, "party": 1385, "book": 1045, "##ize": 1640, "explain": 1098, "1": 7, "hon": 1011, "win": 1891, "night": 633, "##ooo": 1425, "ok": 154, "o": 30, "##id": 140, "ad": 629, "'": 5, "rem": 1671, "obvious": 1440, "today": 488, "been": 602, "romantic": 1691, "back": 529, "connect": 1808, "bff": 1996, "exams": 1303, "your": 144, "##wh": 895, "suggest": 1365, "bitch": 1018, "sweet": 647, "##smilingfacewithsmilingeyes": 1427, "re": 175, "enjoy": 742, "##ai": 1939, "but": 193, "##llow": 1356, "opin": 1852, "##iment": 1867, "##ound": 582, "seen": 1212, "is": 118, "un": 600, "compl": 904, "always": 474, "##x": 70, "ho": 745, "##f": 60, "mon": 632, "hear": 727, "##bigeyes": 815, "winkingfacewithtongue": 1666, "sense": 1012, "mom": 1403, "mic": 1454, "lie": 1299, "kid": 783, "##ingcatfacewithhearteyes": 1330, "easy": 1926, "##facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 1326, "##ention": 1635, "##gh": 166, "yu": 1715, "meaning": 1244, "only": 415, "lu": 1662, "here": 436, "stupid": 375, "0": 6, "##from": 1003, "##ath": 636, "know": 153, "##ol": 878, "form": 1740, "dp": 1278, "har": 775, "off": 454, "stay": 972, "tru": 1075, "joking": 1496, "soo": 1225, "##idd": 1570, "msg": 1608, "##but": 1230, "tech": 1745, "##5": 76, "friendship": 1601, "##nect": 1717, "import": 1629, "call": 397, "wtf": 1812, "##po": 550, "through": 1854, "shall": 1407, "when": 367, "old": 868, "tri": 1693, "already": 725, "around": 1286, "sure": 457, "high": 1639, "##su": 1521, "grinningcatface": 1412, "##hands": 1625, "sy": 1578, "##redheart": 1320, "ann": 662, "tra": 1202, "last": 867, "prom": 1664, "##li": 1269, "xd": 1032, "##uck": 718, "prett": 733, "##lf": 316, "do": 121, 
"##eak": 744, "##ily": 1048, "ug": 1327, "drink": 1314, "beach": 1772, "##ingfacewithhearteyes": 637, "g": 22, "##lieved": 863, "fin": 1089, "love": 197, "those": 1206, "##umbs": 879, "##smilingfacewithhearteyes": 988, "sor": 323, "##ach": 749, "##bly": 1390, "##son": 1047, "##alo": 1899, "##rite": 880, "us": 528, "##blem": 677, "##poutingface": 736, "pretty": 743, "gr": 966, "bot": 873, "clear": 1174, "thumbs": 1124, "yr": 1766, "leave": 674, "favor": 1188, "##vious": 1264, "al": 201, "##ilm": 1357, "s": 34, "##pression": 1272, "##ar": 170, "af": 657, "walk": 1694, "place": 793, "gonna": 795, "without": 1198, "##ck": 199, "##ird": 1274, "moment": 1826, "fam": 1142, "math": 1554, "dis": 942, "truth": 1209, "done": 777, "yester": 1311, "studying": 1543, "earth": 1871, "child": 1827, "would": 408, "gon": 779, "problem": 684, "sp": 471, "fee": 289, "##ab": 1057, "fucking": 1153, "probably": 1513, "just": 194, "col": 1526, "these": 1027, "idea": 1123, "##oose": 1537, "he": 314, "conf": 899, "##on": 88, "care": 700, "##use": 1735, "feel": 290, "absolutely": 1928, "student": 1651, "lot": 649, "##ific": 1074, "me": 99, "##eeee": 1670, "life": 424, "absol": 1825, "##ital": 2003, "ohh": 696, "hop": 765, "top": 900, "happ": 250, "gi": 1046, "want": 215, "hot": 1201, "mil": 1797, "##ustr": 1986, "##ular": 1868, "##less": 655, "mood": 854, "##faceangry": 1728, "resp": 1408, "[MASK]": 4, "##fac": 1229, "awes": 589, "ign": 1000, "##ee": 136, "ill": 1486, "##ft": 769, "fr": 689, "ob": 1143, "inde": 1525, "##ating": 584, "ice": 1998, "ve": 627, "once": 1489, "##self": 535, "robot": 928, "inst": 1559, "important": 1760, "facewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoyfacewithtearsofjoy": 1688, "##edface": 348, "second": 1850, "fine": 446, "fell": 1588, "##isappoint": 453, "##ly": 132, "pr": 1070, "en": 353, "##ash": 1270, "oo": 1455, "##right": 1131, "intelligent": 1652, "dar": 957, "numb": 567, "[SEP]": 3, "##gr": 931, "jok": 838, "luck": 913, "vo": 1395, "##inkingfacewithtongue": 1097, "inv": 1558, "again": 510, "##g": 44, "stop": 591, "words": 1285, "chocolate": 1609, "##mil": 205, "count": 1154, "lo": 138, "mad": 570, "used": 1460, "develop": 1730, "##ingfacewithsunglasses": 1655, "##ball": 1541, "##ket": 940, "neither": 1561, "engl": 715, "sarc": 1519, "##winkingface": 1678, "address": 1726, "wish": 711, "##ies": 536, "ly": 1860, "##tt": 442, "underst": 497, "hel": 512, "friend": 293, "##qu": 517, "hehe": 1061, "##ou": 78, "##ie": 279, "any": 246, "##du": 1422, "rude": 523, "##quintingface": 623, "##grinningfacewithsweat": 1304, "computer": 1469, "che": 829, "##ise": 901, "gott": 1804, "yes": 167, "wi": 1473, "alre": 723, "put": 1765, "##ph": 982, "##st": 109, "##ir": 198, "##usedface": 834, "fir": 603, "beaut": 914, "##uff": 1119, "happens": 2010, "few": 1384, "scary": 1971, "##eek": 816, "matter": 1195, "catch": 1972, "didnt": 1359, "##phone": 1974, "unt": 1938, "##ingeyes": 319, "school": 1179, "photos": 1869, "##lieve": 586, "why": 159, "##ither": 985, "by": 329, "sleeping": 1538, "##inkingface": 694, "late": 1190, "belong": 1887, "char": 1435, "waiting": 965, "##openmouth": 1729, "th": 83, "facewithtearsofjoy": 265, "hard": 906, "comput": 1376, "foot": 1763, "welcome": 542, "ms": 1371, "wait": 562, "other": 646, "##oto": 673, "##ot": 257, "mat": 856, "##ice": 304, "##ning": 690, "after": 768, "yours": 1166, "##itely": 1338, "##ow": 101, "this": 236, "[UNK]": 1, "pre": 495, "cause": 950, "missing": 1360, "micro": 1470, "##ut": 134, "into": 1529, "##ia": 1753, "sleepy": 1743, "people": 522, 
"##nt": 338, "##angu": 859, "##ark": 1373, "bf": 1350, "##ence": 643, "##antic": 1623, "fa": 372, "##withtong": 713, "relationship": 1134, "##time": 1312, "##ny": 296, "confus": 1409, "sun": 1866, "guess": 676, "hair": 1487, "failed": 1885, "movies": 1062, "if": 313, "hour": 1317, "##me": 110, "##facean": 1277, "##ke": 122, "##skintone": 1856, "heart": 599, "jokes": 1266, "smilingfacewithhearteyes": 932, "##grinn": 447, "##ari": 1204, "imp": 1208, "understand": 521, "damn": 974, "sry": 1962, "perfect": 1259, "##ure": 431, "dear": 739, "sick": 1819, "##perseveringface": 1970, "times": 1146, "artificial": 1551, "##ror": 1892, "change": 1040, "int": 450, "##cl": 1983, "start": 664, "looks": 1478, "pu": 1750, "wh": 85, "lucky": 1595, "##disappointedface": 702, "enjoying": 1249, "##ouble": 1546, "idk": 1882, "deser": 1957, "bir": 1117, "##rollingeyes": 1853, "watching": 1161, "##by": 496, "intell": 920, "funny": 309, "##oooo": 1127, "##withtearsof": 156, "brokenheart": 1947, "wonder": 1413, "ty": 707, "also": 461, "myself": 969, "##p": 54, "coll": 1005, "join": 1509, "wife": 1908, "confused": 1886, "##ries": 1943, "interest": 710, "##ll": 96, "se": 373, "##smil": 247, "##in": 79, "see": 300, "whats": 613, "part": 748, "interesting": 1148, "##sof": 151, "##savoringfood": 1306, "please": 339, "##ign": 1374, "mine": 947, "d": 19, "mr": 1933, "every": 418, "##ea": 86, "su": 233, "##times": 1155, "move": 1344, "##d": 56, "##ult": 1347, "##mes": 1059, "wasn": 1506, "##ages": 1245, "##ex": 1569, "frnd": 1466, "open": 1607, "bat": 1656, "hur": 520, "from": 364, "com": 451, "##is": 123, "lets": 1001, "smilingcatfacewithhearteyes": 1684, "##xt": 566, "chill": 1426, "ter": 1713, "##hhh": 1949, "wa": 507, "ex": 306, "##eep": 350, "##ates": 1397, "that": 131, "break": 778, "kiss": 891, "##lc": 524, "##ed": 114, "chic": 1784, "##riend": 228, "watched": 1530, "questions": 1102, "##grinningsquintingfacegrinningsquintingface": 1907, "ext": 2005, "things": 704, "##ightlyfrowningface": 1707, "##ich": 378, "##oot": 1329, "swee": 605, "pe": 494, "bt": 1044, "##where": 1363, "##cess": 1634, "##vice": 1555, "tiredface": 1872, "##k": 52, "ar": 592, "guys": 1414, "giving": 1708, "cou": 818, "it": 108, "##age": 469, "##brokenheartbrokenheart": 1596, "talking": 390, "##you": 1445, "sister": 1622, "plea": 326, "##osed": 1811, "okay": 368, "##al": 141, "av": 1761, "##end": 168, "##umb": 361, "##de": 876, "world": 903, "asking": 686, "##ange": 898, "##oring": 620, "##les": 1503, "weekend": 1482, "expect": 1672, "gener": 1641, "post": 1419, "##apost": 1674, "##loudlycryingfaceloudlycryingface": 493, "inf": 1488, "##skint": 1775, "days": 1017, "bel": 1582, "cute": 962, "as": 210, "ch": 225, "forg": 842, "abt": 1217, "##oz": 1667, "state": 1741, "complete": 1870, "prefer": 1787, "okk": 1965, "##ci": 1220, "##loudlycryingface": 356, "expressionlessface": 1946, "calling": 1723, "##ba": 714, "spec": 1083, "##ld": 190, "##own": 720, "cu": 1308, "trouble": 1847, "##os": 871, "eng": 1337, "##ump": 1785, "birth": 1295, "##bile": 1391, "um": 1752, "ahead": 1956, "tr": 440, "girl": 443, "dreams": 1727, "dance": 1267, "##ps": 1066, "having": 1004, "away": 976, "##apo": 981, "fool": 1443, "asked": 831, "software": 1574, "headache": 1790, "##unamusedface": 1449, "100": 2013, "side": 1862, "##ers": 360, "##ving": 531, "##4": 73, "blocked": 2014, "##ared": 1520, "##ust": 178, "fant": 1931, "coup": 1919, "piss": 1935, "##een": 1268, "co": 209, "spo": 2008, "blue": 1883, "supp": 1158, "##ingcatface": 485, "unamusedface": 1405, "clo": 1394, 
"intern": 1709, "cric": 1749, "x": 39, "##ve": 104, "f": 21, "big": 1050, "##ited": 1192, "own": 1197, "##el": 1463, "##not": 1522, "##wn": 956, "##fast": 1982, "slightlysmilingface": 1477, "thing": 578, "clearly": 1653, "beautiful": 1019, "##ame": 241, "profile": 1994, "room": 1388, "one": 248, "whatever": 1126, "word": 1234, "##em": 1241, "darling": 1322, "##ys": 1980, "bro": 456, "lea": 514, "##ude": 410, "##om": 476, "yo": 1567, "##are": 639, "##0": 72, "fl": 958, "ye": 1151, "telling": 1336, "pa": 1764, "cho": 933, "##dy": 509, "not": 120, "##sed": 1894, "qu": 395, "time": 341, "worst": 1649, "annoy": 667, "exactly": 1180, "class": 1361, "dream": 1022, "propose": 1988, "##ond": 1125, "##ift": 1271, "up": 274, "mu": 354, "smart": 1149, "##urb": 1275, "them": 625, "was": 229, "sleep": 463, "war": 1875, "##led": 1768, "him": 821, "##joy": 148, "exam": 762, "making": 1133, "type": 1093, "##rent": 1984, "##ould": 239, "smil": 462, "go": 172, "##lfriend": 574, "##7": 71, "##3": 77, "nope": 678, "##ance": 781, "##butrelievedface": 1924, "favou": 1339, "norm": 1975, "badly": 1879, "even": 482, "sl": 359, "were": 601, "else": 782, "##uu": 975, "upset": 1185, "reme": 1067, "living": 1942, "lunch": 1500, "##food": 1283, "##eart": 271, "res": 869, "##withtearsofjoy": 157, "##ad": 158, "least": 1592, "##nce": 1695, "##sp": 1981, "na": 619, "##steamfromn": 1014, "##a": 42, "tell": 221, "##ation": 399, "##gn": 936, "##ell": 142, "grinn": 363, "amazing": 929, "disappointedface": 788, "##ush": 865, "##estion": 500, "v": 37, "##ul": 437, "##ti": 848, "try": 477, "fast": 1068, "bl": 726, "##iv": 1767, "##ail": 935, "##ger": 1343, "sur": 1590, "##ing": 81, "has": 754, "##set": 1108, "while": 1523, "##ion": 212, "cryingface": 853, "##faceblowingakissfaceblowingakiss": 1573, "ton": 1502, "##ough": 698, "you": 80, "##other": 952, "special": 1377, "##irk": 1465, "sat": 1818, "##tiredface": 1954, "2": 8, "##int": 275, "##inner": 1007, "could": 668, "piz": 1461, "##kenheart": 945, "##il": 160, "fuck": 379, "##ways": 458, "head": 963, "##tim": 1101, "proba": 1436, "goes": 2002, "des": 1039, "##ings": 401, "sing": 680, "##irt": 2004, "catfacewithtearsofjoy": 1495, "##outh": 979, "##ingfacewithopenmouth": 1967, "and": 152, "di": 1118, "##quintingfacewithtongue": 1725, "##sure": 1462, "die": 1794, "##loudlycryingfaceloudlycryingfaceloudlycryingfaceloudlycryingface": 1221, "be": 124, "##mor": 738, "enough": 919, "anyone": 1121, "ev": 1731, "what": 102, "##us": 288, "tired": 1129, "ugly": 1367, "##ssible": 1476, "artific": 1532, "contact": 1593, "turn": 1758, "peop": 513, "thumbsup": 1485, "half": 1580, "##arc": 1448, "others": 1838, "sh": 217, "##ost": 1372, "##2": 69, "keep": 813, "till": 1677, "##un": 444, "body": 1830, "##ition": 1824, "chat": 487, "saw": 1494, "wont": 1881, "##y": 46, "nothing": 345, "kind": 630, "##med": 1606, "find": 660, "both": 1009, "##steamfromnose": 1035, "##yy": 687, "yet": 890, "disappointed": 1587, "impress": 1701, "du": 1630, "##j": 64, "yup": 809, "tonight": 1611, "thanks": 358, "det": 1858, "##outingcatface": 1844, "##sweat": 697, "all": 299, "facewithtearsofjoyfacewithtearsofjoy": 654, "use": 801, "da": 1351, "##s": 51, "k": 26, "##ff": 538, "##per": 561, "joke": 650, "leaving": 1829, "##mm": 292, "##key": 1581, "bu": 837, "interested": 1176, "microsoft": 1515, "##ric": 1178, "##blowing": 835, "##alent": 1680, "##iew": 1918, "no": 92, "##ud": 264, "##its": 1491, "self": 1029, "u": 36, "##ra": 376, "emo": 1849, "##op": 261, "same": 472, "##bb": 1696, "ref": 1985, "mistake": 1909, 
"kn": 146, "liked": 1993, "game": 806, "##ty": 1456, "miss": 473, "yay": 1170, "##isapp": 448, "said": 467, "bi": 1237, "##blowingakiss": 840, "##ge": 277, "to": 94, "de": 396, "i": 24, "facewith": 995, "super": 944, "kissingcatface": 1594, "##smilingface": 918, "yourself": 792, "##gry": 387, "##steamfrom": 1013, "pleasure": 1618, "loudlycryingfaceloudlycryingface": 1704, "##zy": 943, "almost": 1878, "##l": 57, "loud": 1553, "vis": 1890, "##ful": 926, "bc": 1307, "##grinningfacewithsmilingeyes": 1248, "home": 598, "lovely": 1901, "sexy": 1953, "##ink": 227, "##z": 67, "##n": 58, "tw": 907, "##roll": 1289, "father": 1997, "##aa": 771, "tur": 1261, "##lo": 480, "single": 1087, "poutingface": 1100, "phone": 763, "del": 1528, "##ause": 305, "donapost": 1876, "##pl": 449, "working": 1081, "##ryingface": 235, "english": 734, "##glasses": 1636, "dog": 1739, "##pped": 1742, "link": 1370, "##oint": 383, "able": 1616, "awesome": 596, "fav": 708, "per": 817, "totally": 1711, "##lievedface": 877, "##ral": 1821, "deep": 1930, "sal": 1912, "speak": 797, "cheated": 1468, "om": 1934, "##yyyy": 1624, "get": 242, "##ady": 669, "awww": 1702, "##ering": 1254, "##ect": 799, "##slightlysmilingface": 1923, "appreci": 1929, "going": 382, "##reen": 1780, "ne": 203, "##ary": 1058, "##cent": 1464, "##ha": 187, "12": 1747, "future": 1421, "grinningcatfacewithsmilingeyes": 1510, "happy": 336, "kinda": 1540, "hw": 1932, "depends": 1565, "girlfriend": 581, "##ous": 389, "app": 760, "##ali": 1964, "news": 1637, "##ain": 307, "photo": 679, "pain": 948, "microsof": 1498, "who": 259, "problems": 1771, "never": 435, "##ry": 125, "##line": 1163, "correct": 1341, "haha": 263, "##fer": 1157, "ru": 1665, "point": 1169, "games": 1732, "question": 530, "anyway": 1345, "bad": 315, "hurts": 1235, "opt": 1999, "dark": 2011, "##rect": 1053, "hate": 318, "langu": 874, "yaa": 1591, "amaz": 870, "##owing": 767, "##vers": 983, "yaar": 1194, "ro": 563, "short": 1966, "married": 1348, "##ations": 1846, "hell": 501, "##vel": 927, "j": 25, "disappoint": 1379, "birthday": 1497, "very": 280, "##gg": 740, "add": 992, "laugh": 682, "age": 1224, "sill": 1839, "share": 1092, "##ac": 701, "##ict": 1549, "l": 27, "cry": 631, "wat": 398, "##sh": 681, "##ract": 1673, "sounds": 1710, "fail": 1152, "haven": 1255, "online": 1310, "she": 540, "grinningface": 1073, "##fe": 344, "exc": 1116, "jo": 351, "br": 934, "some": 216, "##as": 181, "facewithsteamfromnose": 1467, "taking": 1746, "god": 1216, "##ent": 291, "##smilingeyes": 327, "goo": 1056, "welc": 537, "crazy": 996, "broken": 1393, "##llywood": 1471, "##bs": 1556, "##ment": 851, "sol": 1963, "choice": 1619, "tried": 1798, "liar": 1898, "started": 1430, "7": 13, "heard": 1511, "##ome": 281, "##ip": 608, "##he": 787, "##rit": 465, "huh": 1293, "fall": 1880, "hurting": 1810, "##ting": 421, "differe": 1189, "worried": 1968, "under": 475, "write": 1396, "##ro": 165, "nah": 1450, "youre": 1253, "meet": 571, "reason": 1256, "ey": 1352, "##ware": 1433, "nd": 1280, "depress": 1054, "shit": 1358, "quite": 1564, "anymore": 1159, "though": 845, "##ors": 1524, "##iness": 1183, "##steam": 994, "##oy": 133, "down": 894, "dec": 1809, "hi": 747, "4": 10, "nt": 1332, "usually": 2009, "fun": 245, "##ine": 285, "grinningfacewith": 691, "##ber": 804, "cr": 1031, "##ually": 597, "believe": 902, "##loor": 1744, "for": 162, "disturb": 1364, "pri": 1577, "ir": 604, "p": 31, "##ingfacewith": 222, "sc": 653, "someone": 588, "wor": 268, "##grinningface": 1076, "indeed": 1643, "##ache": 1706, "looking": 1063, "bet": 481, 
"##erest": 705, "more": 422, "##ingfacewithsmilingeyes": 343, "an": 106, "hand": 1130, "6": 12, "sw": 532, "##ix": 1736, "oh": 249, "laughing": 1175, "car": 1181, "conversation": 1247, "##pect": 1147, "##mouth": 1262, "##lt": 1287, "my": 126, "ins": 1120, "##ere": 177, "many": 695, "forgot": 1531, "##ate": 220, "##rowningface": 1138, "##or": 103, "boring": 862, "##uture": 1411, "20": 1420, "##rible": 1944, "ag": 425, "##ya": 1913, "##ile": 756, "##brokenheart": 1210, "company": 1921, "##cat": 433, "##ress": 502, "getting": 770, "live": 802, "##frowningface": 1432, "tot": 1536, "wearyface": 1632, "##ple": 1023, "##ames": 1301, "con": 386, "depend": 1429, "subject": 1873, "come": 419, "##oudly": 294, "##im": 407, "now": 218, "means": 656, "##ouse": 1205, "hav": 1865, "sometimes": 1167, "##confoundedface": 1976, "##big": 810, "##ist": 483, "ear": 911, "##6": 75, "excited": 1642, "search": 1604, "way": 544, "##mber": 960, "angryface": 1703, "hurt": 560, "suffering": 1793, "eating": 1889, "are": 112, "##ng": 1091, "##iz": 794, "##haha": 1290, "plans": 1920, "##ther": 384, "which": 400, "##ory": 825, "far": 1342, "ignore": 1572, "ll": 429, "on": 139, "##az": 731, "might": 1417, "boy": 572, "##if": 430, "look": 515, "##ens": 930, "got": 478, "##ual": 1389, "emot": 1689, "##ward": 1615, "##hip": 1589, "month": 1479, "##oo": 116, "##ser": 1940, "##za": 1423, "india": 814, "##act": 703, "##ret": 1084, "bollywood": 1848, "picture": 717, "tea": 855, "grinningfacewithsweat": 1010, "##ored": 973, "te": 617, "##ourse": 800, "take": 611, "soon": 949, "remember": 1135, "trying": 905, "##lling": 1605, "rain": 1911, "diff": 939, "tomorrow": 823, "silly": 1874, "##eal": 923, "cur": 1659, "family": 1452, "##reat": 413, "wri": 1645, "little": 889, "##app": 394, "##ear": 115, "year": 675, "sent": 875, "##ese": 886, "message": 971, "obviously": 1654, "reply": 612, "ur": 298, "female": 1927, "##grinningfacewith": 884, "##ightly": 826, "##es": 95, "heal": 1816, "##na": 371, "##catface": 460, "##ess": 276, "##ial": 712, "##body": 846, "such": 924, "idiot": 1094, "good": 183, "his": 1815, "college": 1223, "topic": 1065, "##ss": 337, "annoying": 758, "##amusedface": 910, "favorite": 1251, "##all": 1281, "##gin": 1840, "##e": 53, "el": 729, "##ight": 244, "##pensiveface": 1914, "haa": 2001, "cold": 1562, "list": 732, "whatsapp": 1756, "told": 822, "its": 369, "##ard": 941, "##ill": 184, "##ery": 267, "bus": 671, "full": 1033, "gift": 1777, "##art": 388, "nice": 402, "li": 127, "##unch": 1228, "isn": 1265, "fake": 1112, "serious": 665, "##astic": 1437, "intellig": 938, "##hand": 977, "think": 286, "##ity": 666, "become": 1617, "##ling": 1026, "##lessface": 1362, "between": 1724, "##iful": 1008, "send": 333, "##sol": 1782, "match": 1884, "works": 1903, "##ction": 1507, "boys": 1904, "alright": 1398, "##ingcatfacewithsmilingeyes": 1294, "##ensiveface": 1431, "fan": 1612, "chicken": 1834, "##ister": 1539, "made": 724, "everyone": 908, "exact": 1110, "poor": 1258, "true": 881, "##so": 137, "##th": 89, "pur": 1239, "waste": 1820, "useless": 1906, "study": 839, "shy": 1987, "n": 29, "sch": 1071, "grinningfacewithbigeyes": 1088, "since": 1518, "didn": 526, "hahaha": 659, "ca": 310, "##ue": 412, "country": 1441, "##sk": 1568, "ohk": 1682, "##thumbs": 1738, "c": 18, "3": 9, "honey": 1992, "text": 986, "a": 16, "ha": 105, "wt": 915, "m": 28, "##lie": 565, "rea": 200, "won": 610, "aww": 1186, "before": 1079, "baby": 595, "relation": 1064, "##ower": 1915, "date": 1279, "5": 11, "bored": 1141, "yea": 253, "experience": 1888, "##gs": 
1334, "pers": 564, "knew": 1319, "team": 1863, "##quint": 621, "##ingcatfacewith": 858, "##withtongue": 759, "emoj": 2015, "##winkingfacewithtongue": 1718, "##bu": 1802, "9": 15, "##ey": 191, "##fore": 1060, "sn": 1734, "prof": 1428, "gud": 1576, "found": 1795, "dam": 897, "where": 302, "##ant": 568, "##ically": 1805, "fo": 1099, "irrit": 640, "cra": 791, "busy": 819, "maybe": 866, "##poutingfacepoutingface": 1072, "att": 1199, "business": 1832, "##hhhh": 1755, "man": 426, "bed": 909, "##my": 1893, "##der": 984, "##gl": 590, "sugg": 1145, "hmmm": 1222, "fore": 1369, "nor": 1424, "or": 320, "seems": 1483, "free": 959, "devel": 1722, "dnt": 1316, "cl": 580, "years": 1236, "wht": 953, "definitely": 1378, "so": 111, "##w": 63, "reading": 1754, "buddy": 1792, "##ean": 252, "really": 269, "job": 757, "##ap": 1095, "##facesavoringfood": 1896, "let": 420, "##hh": 468, "marry": 1406, "alone": 541, "##wor": 1823, "wow": 503, "##v": 59, "follow": 1627, "dri": 1960, "kidding": 989, "##earsof": 155, "play": 579, "##ire": 1353, "red": 955, "##beamingfacewithsmilingeyesbeamingfacewithsmilingeyes": 1451, "using": 1685, "meant": 1346, "r": 33, "##h": 45, "##um": 251, "##eary": 917, "thats": 645, "cannot": 1719, "##hearteyes": 549, "btw": 1626, "dont": 328, "##8": 68, "seriously": 1042, "happen": 533, "##ply": 559, "##ingsquintingface": 728, "coz": 1128, "smile": 1584, "than": 211, "##c": 61, "lik": 1231, "everything": 746, "breakfast": 1991, "##rr": 1380, "sa": 380, "##ead": 622, "typ": 1481, "##earch": 1547, "wom": 1182, "appre": 1905, "exper": 1700, "sending": 1828, "together": 1557, "##pe": 366, "cont": 824, "non": 1733, "##sy": 1841, "music": 1055, "loudlycryingface": 652, "##ying": 1318, "##facewith": 980, "being": 576, "sim": 1333, "mo": 207, "##ions": 1232, "##rong": 551, "##ew": 1211, "##ay": 113, "they": 470, "##fl": 1633, "her": 569, "##ide": 545, "##open": 1600, "gotta": 1959, "proper": 1969, "cant": 888, "feelings": 1257, "##ightlysmilingface": 1049, "forever": 1544, "ga": 1836, "answer": 609, "answ": 548, "##rown": 872, "babe": 1563, "idi": 1077, "##ale": 1137, "##pend": 1323, "suck": 1806, "##oud": 270, "movie": 486, "##ding": 641, "over": 730, "am": 135, "##ream": 785, "##sm": 1535, "##ingfacewithsun": 1648, "sit": 1252, "loved": 1291, "##ree": 585, "kill": 1692, "le": 709, "creat": 990, "end": 1037, "comm": 1006, "op": 1090, "##i": 55, "soft": 1446, "next": 844, "eat": 833, "lame": 1296, "angry": 445, "hey": 546, "##en": 97, "crush": 1401, "##kk": 1078, "indi": 642, "##eamingfacewithsmilingeyes": 505, "sec": 892, "hang": 1661, "st": 163, "brother": 1571, "money": 921, "show": 672, "ti": 284, "##most": 1801, "refer": 1916, "cheat": 1015, "their": 1114, "##ake": 282, "coming": 1041, "##ture": 699, "songs": 1480, "chocol": 1416, "playing": 1687, "##est": 240, "thnx": 1925, "##ig": 377, "buy": 1542, "##rowningfacewithopenmouth": 1973, "lost": 504, "##ven": 391, "##sl": 1214, "gen": 1105, "hello": 937, "##con": 1328, "hor": 1387, "went": 1714, "##friend": 860, "##oudlycryingface": 297, "##cing": 1822, "glad": 827, "hope": 812, "sad": 340, "##av": 968, "left": 1111, "woman": 1773, "secret": 1246, "ready": 1475, "##the": 1895, "##re": 84, "saying": 741, "##se": 130, "there": 381, "choc": 1349, "bas": 1657, "weird": 1458, "ac": 893, "depressed": 1305, "says": 1683, "cuz": 1759, "work": 441, "how": 169, "expl": 1016, "hu": 786, "repl": 1560, "##ming": 1716, "wan": 185, "irritating": 798, "relievedface": 1335, "ma": 1861, "comp": 852, "##man": 1297, "##conf": 1675, "broke": 832, "##q": 66, "##iddle": 
1776, "##ank": 1843, "nons": 1501, "cool": 414, "had": 558, "aren": 1207, "hum": 618, "##00": 1646, "##eam": 405, "##oom": 1024, "will": 226, "dist": 1104, "grinningsquintingface": 1132, "##9": 74, "ya": 427, "##am": 230, "right": 428, "##ute": 722, "week": 850, "winkingface": 1799, "long": 776, "anything": 556, "##fect": 1096, "then": 258, "friends": 508, "language": 998, "different": 1442, "##be": 607, "ab": 192, "##ken": 615, "id": 721, "##ert": 1781, "thank": 234, "program": 1168, "##ace": 93, "##ass": 624, "milk": 1995, "supposed": 1833, "either": 1552, "say": 262, "##iss": 321, "##faceblowingakiss": 1028, "morning": 1020, "##ushedface": 1922, "scared": 1952, "mess": 661, "tht": 1669, "smilingface": 1948, "choose": 1845, "##zz": 1392, "name": 459, "mov": 411, "cat": 803, "##er": 90, "##severingface": 1712, "##day": 349, "##sw": 392, "##an": 91, "##lock": 964, "doesn": 1002, "low": 1621, "##ensive": 1331, "##guish": 1628, "##log": 1459, "books": 1851, "##ally": 417, "##eyes": 232, "##gether": 1508, "##ub": 1282, "##with": 117, "listen": 808, "help": 525, "##rinn": 255, "##ith": 100, "##m": 50, "human": 692, "don": 128, "may": 554, "crying": 1086, "##ak": 334, "coffee": 1499, "##ject": 1164, "water": 1650, "thr": 1504, "##r": 47, "came": 1658, "##ool": 342, "boyfriend": 882, "each": 1859, "lol": 409, "##ri": 145, "thinking": 1109, "si": 1219, "mobile": 1434, "pic": 370, "gu": 455, "real": 616, "thn": 1354, "two": 1200, "mist": 1516, "city": 1368, "##bro": 1107, "wear": 1614, "str": 1548, "hmm": 499, "bo": 557, "smilingfacewithsmilingeyes": 1187, "##rom": 325, "vide": 951, "cricket": 1855, "##ive": 317, "personal": 1770, "art": 1103, "##orm": 1941, "##cept": 1184, "jud": 1978, "iam": 841, "at": 295, "wanted": 1325, "dr": 1043, "rel": 796, "stuff": 1243, "compliment": 1955, "mar": 626, "##ship": 1069, "too": 256, "lear": 954, "of": 149, "##our": 492, "pick": 1493, "##sav": 1298, "##cryingface": 254, "##oringfood": 1302, "##any": 1598, "##aught": 1681, "##wearyface": 1864, "mor": 946, "out": 432, "sub": 1699, "shut": 849, "spe": 967, "happened": 772, "##facewithtearsofjoy": 213, "##bot": 828, "watch": 516, "stud": 628, "##iff": 925, "##thing": 195, "learn": 1250, "##ack": 1213, "rom": 1410, "pizza": 1514, "need": 385, "song": 635, "##outingface": 577, "##grinningfacewithbigeyes": 1439, "gf": 807, "redheart": 1586, "rec": 1720, "sound": 1177, "whole": 1527, "##ver": 182, "dat": 1762, "film": 1796, "can": 161, "thou": 484, "##ated": 753, "actually": 670, "grinningfacewithsmilingeyes": 1160, "did": 243, "w": 38, "##ick": 706, "bea": 751, "##llyw": 1457, "less": 1663, "plan": 843, "person": 575, "##ween": 1697, "mean": 278, "gl": 773, "called": 1193, "umm": 1613, "##reak": 719, "engine": 1512, "pass": 1517, "about": 208, "does": 552, "forget": 1415, "[PAD]": 0, "pics": 1382, "##cryingfacecryingface": 1233, "##pose": 1321, "##ather": 1082, "talk": 214, "##way": 355, "acc": 1400, "give": 423, "health": 1910, "later": 1038, "##irkingface": 1534, "faceblowingakiss": 1783, "hm": 1660, "our": 883, "plz": 857, "##old": 1800, "##ved": 1897, "##ness": 1474, "##blowingak": 836, "sha": 648, "b": 17, "##ort": 683, "##bo": 750, "gir": 331, "dude": 922, "##ete": 1288, "rest": 1603, "##sun": 1597, "##face": 98, "tak": 1240, "mus": 987, "##ble": 404, "##ch": 143, "nobody": 1242, "##ter": 283, "wouldn": 1989, "bye": 547, "bit": 658, "eas": 1644, "better": 553, "yeah": 287, "lonely": 805, "em": 820, "lat": 912, "sm": 737, "yesterday": 1315, "much": 362, "women": 1958, "##ag": 587, "##ince": 1215, "##rim": 1447, 
"##earyface": 1139, "##ats": 1418, "##asses": 1602, "##ingface": 119, "new": 593, "##rimac": 1610, "##isappointedface": 489, "##ore": 406, "##ious": 498, "ohhh": 1080, "##ind": 301, "convers": 1226, "cor": 1238, "worry": 1404, "##alk": 202, "suff": 1583, "first": 614, "hy": 1814, "imag": 1807, "reg": 1877, "##ic": 180, "ce": 1835, "should": 439, "read": 896, "##ount": 1679, "def": 1030, "##press": 634, "stand": 1900, "##bl": 693, "wrong": 583, "inter": 1172, "##ics": 1721, "##at": 82, "still": 594, "t": 35, "##ite": 543, "number": 606, "most": 761, "met": 1599, "##able": 1052, "##rimacingface": 1620, "ai": 774, "##arent": 1831, "##smilingfacewithhearteyessmilingfacewithhearteyes": 2012, "##orn": 1165, "##di": 464, "##ience": 1156, "well": 322, "indian": 1399, "trust": 1292, "##ever": 916, "e": 20, "##ww": 830, "##more": 1136, "favourite": 1386, "with": 186, "##ves": 864, "##ose": 663, "##thers": 1647, "yess": 1803, "wo": 1579, "missed": 1122, "##ied": 1402, "##ional": 1902, "tat": 1979, "trash": 2007, "intelligence": 1313, "tomor": 811, "seem": 970, "##facewithsteamfromnose": 1757, "ask": 308, "##eat": 527, "course": 885, "loves": 1300, "guy": 755, "ser": 555} \ No newline at end of file diff --git a/test/data/sentencepiece.bpe.model b/test/data/sentencepiece.bpe.model new file mode 100644 index 00000000..db9af13b Binary files /dev/null and b/test/data/sentencepiece.bpe.model differ diff --git a/test/test_tools_add_pre_post_processing_to_model.py b/test/test_tools_add_pre_post_processing_to_model.py index 9c337f16..b43a433e 100644 --- a/test/test_tools_add_pre_post_processing_to_model.py +++ b/test/test_tools_add_pre_post_processing_to_model.py @@ -7,7 +7,7 @@ import io import numpy as np import onnxruntime as ort import os -import sys +from typing import List from PIL import Image from pathlib import Path @@ -16,6 +16,9 @@ from distutils.version import LooseVersion # `pip install -e .` from the repo root. from onnxruntime_extensions import get_library_path from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp +from onnxruntime_extensions.tools import pre_post_processing as pre_post_processing +from onnxruntime_extensions.tools.pre_post_processing.steps import * + script_dir = os.path.dirname(os.path.realpath(__file__)) ort_ext_root = os.path.abspath(os.path.join(script_dir, "..")) @@ -59,8 +62,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): input_batch = input_tensor.unsqueeze( 0).detach().cpu().numpy() # create a mini-batch as expected by the model - s = ort.InferenceSession(input_model) - scores = s.run(None, {'x': np.array(input_batch)}) + s = ort.InferenceSession(input_model, providers=['CPUExecutionProvider']) + scores = s.run(None, {"x": np.array(input_batch)}) scores = np.squeeze(scores) def softmax(x): @@ -75,8 +78,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): so = ort.SessionOptions() so.register_custom_ops_library(get_library_path()) - s = ort.InferenceSession(output_model, so) - probabilities = s.run(None, {'image': np.array(input_bytes)})[0] + s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider']) + probabilities = s.run(None, {"image": np.array(input_bytes)})[0] probabilities = np.squeeze(probabilities) # remove batch dim return probabilities @@ -100,6 +103,7 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): # can still use PT pre-processing as it's using PIL for images. # Update the Normalize values to match TF requirements. 
from torchvision import transforms + input_image = Image.open(input_image_path) preprocess = transforms.Compose([ transforms.Resize(256), @@ -108,12 +112,13 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) input_tensor = preprocess(input_image) - input_batch = input_tensor.unsqueeze( - 0).detach().cpu().numpy() # create a mini-batch as expected by the model - input_batch = np.transpose(input_batch, (0, 2, 3, 1)) # to NHWC format for TF input + # create a mini-batch as expected by the model + input_batch = (input_tensor.unsqueeze(0).detach().cpu().numpy()) + # to NHWC format for TF input + input_batch = np.transpose(input_batch, (0, 2, 3, 1)) - s = ort.InferenceSession(input_model) - probabilities = s.run(None, {'input': np.array(input_batch)})[0] + s = ort.InferenceSession(input_model, providers=['CPUExecutionProvider']) + probabilities = s.run(None, {"input": np.array(input_batch)})[0] return np.squeeze(probabilities) def new_output(): @@ -123,8 +128,8 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): so = ort.SessionOptions() so.register_custom_ops_library(get_library_path()) - s = ort.InferenceSession(output_model, so) - probabilities = s.run(None, {'image': np.array(input_bytes)})[0] + s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider']) + probabilities = s.run(None, {"image": np.array(input_bytes)})[0] return np.squeeze(probabilities) # remove batch dim orig_results = orig_output() @@ -158,14 +163,14 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): so = ort.SessionOptions() so.register_custom_ops_library(get_library_path()) - s = ort.InferenceSession(output_model, so) + s = ort.InferenceSession(output_model, so, providers=['CPUExecutionProvider']) - result_bytes = s.run(None, {'image': np.array(input_bytes)})[0] + result_bytes = s.run(None, {"image": np.array(input_bytes)})[0] # convert from png to RGB to remove any png encoding diffs result_img = Image.open(io.BytesIO(result_bytes)) - result = np.array(result_img.convert('RGB')) - expected = np.array(Image.open(expected_output_image_path).convert('RGB')) + result = np.array(result_img.convert("RGB")) + expected = np.array(Image.open(expected_output_image_path).convert("RGB")) # check all pixel values are close. allowing for 0.1% of pixels to differ by 2 at the most. # @@ -174,8 +179,147 @@ class TestToolsAddPrePostProcessingToModel(unittest.TestCase): # whether avx512 is used or not. 
MacOS seems to be slightly worse though (max of 2) diffs = np.absolute(expected.astype(np.int32) - result.astype(np.int32)) total = np.sum(diffs) - print(f'Max diff:{diffs.max()} Total diffs:{total}') + print(f"Max diff:{diffs.max()} Total diffs:{total}") self.assertTrue(diffs.max() < 3 and total < (result.size / 1000)) + def create_pipeline_and_run_for_tokenizer(self, tokenizer_impl, tokenizer_type, + tokenizer_parameters, output_model: Path): + import onnx + create_named_value = pre_post_processing.utils.create_named_value + + inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])] + + pipeline = pre_post_processing.PrePostProcessor(inputs) + # ref_output = list(tokenizer(*input_text, return_tensors="np").values()) + + pipeline.add_pre_processing([tokenizer_impl]) + if tokenizer_type == "HfBertTokenizer_with_decoder": + pipeline.add_post_processing([ + (BertTokenizerQADecoder(tokenizer_parameters), [ + pre_post_processing.utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]) + ]) + input_model = onnx.load(os.path.join(test_data_dir, "../bert_qa_decoder_base.onnx")) + else: + input_model = onnx.ModelProto() + input_model.opset_import.extend([onnx.helper.make_operatorsetid("", 16)]) + new_model = pipeline.run(input_model) + onnx.save_model(new_model, output_model) + return + + def test_sentencepiece_tokenizer(self): + output_model = (Path(test_data_dir) / "../sentencePiece.onnx").resolve() + + input_text = ("This is a test sentence",) + ref_output = [np.array([[0, 3293, 83, 10, 3034, 149357, 2]]), + np.array([[1, 1, 1, 1, 1, 1, 1]])] + # tokenizer = transformers.AutoTokenizer.from_pretrained("xlm-roberta-base") + tokenizer_parameters = TokenizerParam( + vocab_or_file=os.path.join(test_data_dir, "../sentencepiece.bpe.model"), + tweaked_bos_id=0, + ) + tokenizer_impl = SentencePieceTokenizer(tokenizer_parameters, add_eos=True, add_bos=True) + self.create_pipeline_and_run_for_tokenizer( + tokenizer_impl, "SentencePieceTokenizer", tokenizer_parameters, output_model) + + so = ort.SessionOptions() + so.register_custom_ops_library(get_library_path()) + s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"]) + + result = s.run(None, {s.get_inputs()[0].name: np.array([[*input_text]])}) + + # SentencePieceTokenizer in ORT rounds towards zero, so we need to use atol=1 + self.assertEqual(np.allclose(result[0], ref_output[0], atol=1), True) + self.assertEqual(np.allclose(result[1], ref_output[1]), True) + + def test_bert_tokenizer(self): + output_model = (Path(test_data_dir) / "../bert_tokenizer.onnx").resolve() + input_text = ("This is a test sentence",) + ref_output = [ + np.array([[2, 236, 118, 16, 1566, 875, 643, 3]]), + np.array([[0, 0, 0, 0, 0, 0, 0, 0]]), + np.array([[1, 1, 1, 1, 1, 1, 1, 1]]), + ] + # tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert") + tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../bert.vocab"), + do_lower_case=True) + tokenizer_impl = BertTokenizer(tokenizer_parameters) + self.create_pipeline_and_run_for_tokenizer( + tokenizer_impl, "BertTokenizer", tokenizer_parameters, output_model) + + so = ort.SessionOptions() + so.register_custom_ops_library(get_library_path()) + s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"]) + + result = s.run(None, {s.get_inputs()[0].name: np.array([[*input_text]])}) + + self.assertEqual(np.allclose(result[0], ref_output[0]), True) + self.assertEqual(np.allclose(result[1],
ref_output[2]), True) + self.assertEqual(np.allclose(result[2], ref_output[1]), True) + + def test_hfbert_tokenizer(self): + output_model = (Path(test_data_dir) / "../hfbert_tokenizer.onnx").resolve() + ref_output = ([ + np.array([[2, 236, 118, 16, 1566, 875, 643, 3, 236, 118, 978, 1566, 875, 643, 3]]), + np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]), + np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + ] + ) + # tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert") + tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../hfbert.vocab"), + do_lower_case=True, is_sentence_pair=True) + input_text = ("This is a test sentence", "This is another test sentence") + tokenizer_impl = BertTokenizer(tokenizer_parameters) + self.create_pipeline_and_run_for_tokenizer( + tokenizer_impl, "HfBertTokenizer", tokenizer_parameters, output_model) + + so = ort.SessionOptions() + so.register_custom_ops_library(get_library_path()) + s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"]) + + result = s.run(None, {s.get_inputs()[0].name: np.array([[input_text[0], input_text[1]]])}) + + self.assertEqual(np.allclose(result[0], ref_output[0]), True) + self.assertEqual(np.allclose(result[1], ref_output[2]), True) + self.assertEqual(np.allclose(result[2], ref_output[1]), True) + + def test_qatask_with_tokenizer(self): + output_model = (Path(test_data_dir) / "../hfbert_tokenizer.onnx").resolve() + ref_output = [np.array(["[CLS]"])] + # tokenizer = transformers.AutoTokenizer.from_pretrained("lordtt13/emo-mobilebert") + tokenizer_parameters = TokenizerParam(vocab_or_file=os.path.join(test_data_dir, "../hfbert.vocab"), + do_lower_case=True, is_sentence_pair=True) + input_text = ("This is a test sentence", "This is another test sentence") + tokenizer_impl = BertTokenizer(tokenizer_parameters) + self.create_pipeline_and_run_for_tokenizer( + tokenizer_impl, "HfBertTokenizer_with_decoder", tokenizer_parameters, output_model) + + so = ort.SessionOptions() + so.register_custom_ops_library(get_library_path()) + s = ort.InferenceSession(str(output_model), so, providers=["CPUExecutionProvider"]) + + result = s.run(None, {s.get_inputs()[0].name: np.array([[input_text[0], input_text[1]]])}) + + self.assertEqual(result[0][0], ref_output[0][0]) + + # Corner Case + def test_debug_step(self): + import onnx + + create_named_value = pre_post_processing.utils.create_named_value + + # multiple DebugSteps are stringed together + input_model_path = os.path.join(test_data_dir, "pytorch_super_resolution.onnx") + inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])] + pipeline = pre_post_processing.PrePostProcessor(inputs) + # each Debug step adds a new model output + post_processing = [pre_post_processing.Debug(1), pre_post_processing.Debug(1), pre_post_processing.Debug(1)] + + pipeline.add_post_processing(post_processing) + input_model = onnx.load(input_model_path) + new_model = pipeline.run(input_model) + + self.assertEqual(len(new_model.graph.output), len(input_model.graph.output) + len(post_processing)) + + if __name__ == "__main__": unittest.main() diff --git a/tutorials/bert_e2e.py b/tutorials/bert_e2e.py new file mode 100644 index 00000000..6b41d2c1 --- /dev/null +++ b/tutorials/bert_e2e.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import argparse +import os +import shutil +import re +import tempfile +import functools +from pathlib import Path + +from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp +import onnxruntime_extensions + +# for tokenizer +import transformers +import numpy as np +import onnxruntime + + +# avoid loading the model from hugging-face multiple times as it's time consuming +@functools.lru_cache +def get_tokenizer_and_model_from_huggingface(model_name): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + config = transformers.AutoConfig.from_pretrained(model_name) + + if model_name == "xlm-roberta-base": + model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name) + onnx_config = transformers.models.xlm_roberta.XLMRobertaOnnxConfig(config, "sequence-classification") + text = ("Hello, my dog is cute",) + elif model_name == "google/mobilebert-uncased": + model = transformers.MobileBertForNextSentencePrediction.from_pretrained(model_name) + onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "default") + text = ("where is Jim Henson?", "he is at school from where two blocks away") + elif model_name == "csarron/mobilebert-uncased-squad-v2": + model = transformers.MobileBertForQuestionAnswering.from_pretrained(model_name) + onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "question-answering") + text = ("Who was Jim Henson?", "Jim Henson was a nice puppet") + elif model_name == "lordtt13/emo-mobilebert": + model = transformers.MobileBertForSequenceClassification.from_pretrained(model_name) + onnx_config = transformers.models.mobilebert.MobileBertOnnxConfig(config, "sequence-classification") + text = ("Hello, my dog is cute",) + else: + raise ValueError(f"{model_name} is not supported yet.") + return tokenizer, model, onnx_config, text + + +def export_backbone(model_name: str, bert_onnx_model: Path): + """ + Export the backbone onnx model from huggingface. + The exported model usually has the inputs "input_ids", "attention_mask" and "token_type_ids", and tensor outputs. + """ + + # fix the seed so we can reproduce the results + transformers.set_seed(42) + tokenizer, model, onnx_config, text = get_tokenizer_and_model_from_huggingface(model_name) + + if bert_onnx_model and bert_onnx_model.exists(): + print("Using cached ONNX model, skipping re-exporting the backbone model.") + return tokenizer, bert_onnx_model, onnx_config + + # tempfile will be removed automatically + with tempfile.TemporaryDirectory() as tmpdir: + canonized_name = bert_onnx_model.name + tmp_model_path = Path(tmpdir + "/" + canonized_name) + onnx_inputs, onnx_outputs = transformers.onnx.export(tokenizer, model, onnx_config, 16, tmp_model_path) + shutil.copy(tmp_model_path, bert_onnx_model) + return tokenizer, bert_onnx_model, onnx_config + + +def add_pre_post_processing_to_transformers(model_name: str, input_model_file: Path, output_model_file: Path): + """Construct the pipeline for an end2end model with pre and post processing. + The final model can take text as inputs and output the result in text format for models like Q & A. + + Args: + model_name (str): Model to export from hugging-face. Used to infer tokenizer and onnx model backbone. + input_model_file (Path): Path of the cached backbone onnx model; if it doesn't exist, the backbone is exported from hugging-face and saved here. + output_model_file (Path): where to save the final onnx model.
+    """
+    tokenizer, bert_onnx_model, onnx_config = export_backbone(model_name, input_model_file)
+    # some tokenizers do not provide a vocab file, so write the vocab out for the tokenizer step to use
+    if not hasattr(tokenizer, "vocab_file"):
+        vocab_file = bert_onnx_model.parent / "vocab.txt"
+        import json
+        with open(str(vocab_file), 'w') as f:
+            f.write(json.dumps(tokenizer.vocab))
+    else:
+        vocab_file = tokenizer.vocab_file
+    tokenizer_type = 'BertTokenizer' if model_name != 'xlm-roberta-base' else 'SentencePieceTokenizer'
+    task_type = ('NextSentencePrediction' if model_name == 'google/mobilebert-uncased'
+                 else ''.join([i.capitalize() for i in onnx_config.task.split('-')]))
+    add_ppp.transformers_and_bert(bert_onnx_model, output_model_file,
+                                  vocab_file, tokenizer_type,
+                                  task_type,
+                                  add_debug_before_postprocessing=True)
+
+
+def verify_results_for_e2e_model(model_name: str, input_bert_model: Path, output_model_file: Path):
+    """Run the backbone model and the end-to-end model and check that their results match.
+
+    Args:
+        model_name: the Hugging Face model name
+        input_bert_model: the backbone ONNX model, either exported from Hugging Face or provided by the user
+        output_model_file: the finalized ONNX model with pre/post processing that needs to be verified
+    """
+    tokenizer, hg_model, _, text = get_tokenizer_and_model_from_huggingface(model_name)
+    encoded_input = tokenizer(*text, return_tensors="pt")
+    transformers.set_seed(42)
+
+    session_options = onnxruntime.SessionOptions()
+
+    session = onnxruntime.InferenceSession(
+        str(input_bert_model.resolve(strict=True)), providers=["CPUExecutionProvider"]
+    )
+    inputs = {key: value.detach().numpy() for key, value in encoded_input.items()}
+    output_name_for_verify = session.get_outputs()[0].name
+    ref_outputs = session.run([output_name_for_verify], inputs)
+
+    # register the custom operators (e.g. the tokenizer) required by the end-to-end model
+    session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
+
+    session = onnxruntime.InferenceSession(
+        str(output_model_file.resolve(strict=True)), session_options, providers=["CPUExecutionProvider"]
+    )
+
+    inputs = dict(input_text=np.array([[*text]]))
+    real_outputs = session.run([output_name_for_verify + "_debug"], inputs)
+    assert np.allclose(
+        real_outputs[0], ref_outputs[0], atol=1e-2, rtol=1e-6
+    ), f"Results do not match. Expected: {ref_outputs[0]}, got: {real_outputs[0]}"
+
+    print("Results match:", real_outputs[0], "\ndiff:", real_outputs[0] - ref_outputs[0])
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="""Add pre and post processing to a model.
+
+        This tutorial supports updating:
+          - MobileBert with different tasks
+          - XLM-Roberta with a classification task
+
+        It provides an example of how to add pre/post processing to a transformer model.
+        It can add a tokenizer (SentencePiece/BertTokenizer/HfBertTokenizer) for pre-processing,
+        and a classifier/decoder for post-processing.
+
+        The backbone model is exported from Hugging Face if an existing ONNX model is not provided.
+        NOTE: if you provide an ONNX model it must match the selected model_type, as the corresponding
+        Hugging Face tokenizer is used to create the pre-processing.
+        """,
+    )
+
+    parser.add_argument(
+        "-t",
+        "--model_type",
+        type=str,
+        required=True,
+        choices=[
+            "xlm-roberta-base",
+            "google/mobilebert-uncased",
+            "csarron/mobilebert-uncased-squad-v2",
+            "lordtt13/emo-mobilebert",
+        ],
+        help="Model type.",
+    )
+
+    parser.add_argument(
+        "model_path",
+        type=Path,
+        help="""The path to an existing ONNX model, or to a directory to save the model exported from Hugging Face in.
+ This model will be updated to add pre/post processing, and saved in the same location with the suffix + '.with_pre_post_processing.onnx'""", + ) + + args = parser.parse_args() + + model_path = args.model_path.resolve(strict=True) + canonized_name = re.sub(r"[^a-zA-Z0-9]", "_", args.model_type) + ".onnx" + + if model_path.is_dir(): + model_path = model_path / canonized_name + + new_model_path = model_path.with_suffix(".with_pre_post_processing.onnx") + + add_pre_post_processing_to_transformers(args.model_type, model_path, new_model_path) + verify_results_for_e2e_model(args.model_type, model_path, new_model_path) + return new_model_path + + +if __name__ == "__main__": + main()
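
A minimal usage sketch for the model produced by this tutorial, assuming the QuestionAnswering variant
('csarron/mobilebert-uncased-squad-v2') and the default output file naming from bert_e2e.py; because a Debug
step is added before the decoder, the exact number and order of outputs may differ, so all outputs are printed:

    # 1. create the end-to-end model, e.g.
    #      python tutorials/bert_e2e.py -t csarron/mobilebert-uncased-squad-v2 <existing dir or ONNX model path>
    # 2. run the updated model directly on text
    import numpy as np
    import onnxruntime
    import onnxruntime_extensions

    so = onnxruntime.SessionOptions()
    so.register_custom_ops_library(onnxruntime_extensions.get_library_path())
    session = onnxruntime.InferenceSession(
        "csarron_mobilebert_uncased_squad_v2.with_pre_post_processing.onnx",  # name produced by bert_e2e.py
        so,
        providers=["CPUExecutionProvider"],
    )

    question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    # the end-to-end model takes the raw text directly; the input shape is [1, num_sentences]
    outputs = session.run(None, {"input_text": np.array([[question, context]])})
    print(outputs)  # includes the decoded answer text plus the debug output with the raw model scores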