add a gen_processing_model option to cast token-id for int64 (#632)
* add a gen_processing_model option to cast token-id for int64
* Update util.py test pipeline trigger
This commit is contained in:
Parent: b072e94afd
Commit: a32b932547
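The new pre-processing option is easiest to see end to end. A minimal usage sketch, modeled on the unit test added in this commit; the checkpoint name mirrors the test fixture, and the top-level imports are assumed to be the package's public exports:

import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, OrtPyFunction

# Same non-gated checkpoint the new test uses.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

# CAST_TOKEN_ID asks the exporter to cast the token-ID output to int64 when needed.
pre_model, _ = gen_processing_models(
    tokenizer,
    pre_kwargs={"WITH_DEFAULT_INPUTS": True, "CAST_TOKEN_ID": True})

ort_tok = OrtPyFunction.from_model(pre_model)
ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
assert ids.dtype == np.int64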
@@ -23,6 +23,7 @@ class PyCustomOpDef:
     dt_complex64: int = ...
     dt_complex128: int = ...
     dt_bfloat16: int = ...

     def install_hooker(self, invocation_handler: Callable) -> None:
         ...
     ...

@@ -9,8 +9,6 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter

 import json
 import onnx
-import uuid
-import numpy as np
 from numpy import array as nparray
 from functools import partial
 from collections import namedtuple, OrderedDict

@@ -31,7 +29,8 @@ class HFTokenizerConverter(CustomOpConverter):
         # ids = sorted(hf_tokenizer.added_tokens_encoder.values())
         # if not ids == list(range(min(ids), max(ids) + 1)):
         #     raise RuntimeError(f"{hf_tokenizer.__name__}: the ids in added_tokens_encoder are not consecutive")
-        token_map = [f"{_k}={_v}" for _k, _v in hf_tokenizer.added_tokens_encoder.items()]
+        token_map = [f"{_k}={_v}" for _k,
+                     _v in hf_tokenizer.added_tokens_encoder.items()]
         attrs.update({"added_token": "\n".join(token_map)})

         sorted_merges = {v_: k_ for k_, v_ in hf_tokenizer.bpe_ranks.items()}

@@ -42,7 +41,8 @@ class HFTokenizerConverter(CustomOpConverter):
     def bpe_tokenizer(self, **kwargs):
         hf_gpt2_tokenizer = self.tokenizer
         if type(self.tokenizer).__name__.endswith('Fast'):
-            raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
+            raise ValueError(
+                'Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')

         attrs = self.convert_bpe_vocab(hf_gpt2_tokenizer)
         attrs.update(**kwargs)

@@ -51,12 +51,15 @@ class HFTokenizerConverter(CustomOpConverter):
     def bert_tokenizer(self, **kwargs):
         hf_bert_tokenizer = self.tokenizer
         # has to be sorted since the id of token was generated automatically.
-        ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        ordered_vocab = OrderedDict(
+            sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
         vocab = '\n'.join(ordered_vocab.keys())
         attrs = dict(vocab=vocab)
         init_kwargs = hf_bert_tokenizer.init_kwargs
-        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0
-        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0
+        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get(
+            'do_lower_case') else 0
+        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get(
+            'strip_accents') else 0
         attrs.update(**kwargs)
         return attrs

@@ -91,7 +94,8 @@ class HFTokenizerConverter(CustomOpConverter):
         hf_clip_tokenizer = self.tokenizer

         if type(self.tokenizer).__name__.endswith('Fast'):
-            raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')
+            raise ValueError(
+                'Please use the slow version of the tokenizer (ex: CLIPTokenizer).')

         attrs = self.convert_bpe_vocab(hf_clip_tokenizer)
         attrs.update(**kwargs)

@@ -101,7 +105,8 @@ class HFTokenizerConverter(CustomOpConverter):
         hf_roberta_tokenizer = self.tokenizer

         if type(self.tokenizer).__name__.endswith('Fast'):
-            raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')
+            raise ValueError(
+                'Please use the slow version of the tokenizer (ex: RobertaTokenizer).')

         attrs = self.convert_bpe_vocab(hf_roberta_tokenizer)
         attrs.update(**kwargs)

@@ -133,7 +138,7 @@ _PROCESSOR_DICT = {
     "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
                                         'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
     "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                      'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
     "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
                                      'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
     "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,

@@ -167,7 +172,8 @@ class HFTokenizerOnnxGraph:

     @staticmethod
     def extract_cls_name(processor):
-        cls_name = processor if isinstance(processor, str) else type(processor).__name__
+        cls_name = processor if isinstance(
+            processor, str) else type(processor).__name__
         if cls_name.endswith("TokenizerFast"):
             cls_name = cls_name[:-len("Fast")]
         return cls_name

@@ -184,6 +190,8 @@ class HFTokenizerOnnxGraph:

     def pre_processing(self, **kwargs):
         with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+        cast_token_id = kwargs.pop("CAST_TOKEN_ID", False)
+
         _cvt_op = self.cvt_quadruple.pre_op
         _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)

@@ -200,22 +208,41 @@ class HFTokenizerOnnxGraph:
         if self.cvt_quadruple.default_inputs is not None:
             default_inputs.update(self.cvt_quadruple.default_inputs)
         if len(default_inputs) != n_inputs:
-            raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
+            raise ValueError(
+                "Op: {} does not have the inputs from its TokenOpParam.".format(_cvt_op))

         new_initializers = []

         for k, v in default_inputs.items():
             input_value_info = next((i for i in g.input if i.name == k), None)
             if input_value_info is None:
-                raise ValueError("The input {} is not found in the graph".format(k))
+                raise ValueError(
+                    "The input {} is not found in the graph".format(k))

-            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+                input_value_info.type.tensor_type.elem_type)
             value = nparray(v, np_dtype)
             new_initializers.append(onnx.numpy_helper.from_array(value, k))
         g.initializer.extend(new_initializers)
         new_inputs = [i for i in g.input if i.name not in default_inputs]
         g.ClearField("input")
         g.input.extend(new_inputs)

+        if cast_token_id:
+            # assume the first output is always the token ID.
+            if g.output[0].type.tensor_type.elem_type != onnx.onnx_pb.TensorProto.INT64:
+                new_output_name = g.output[0].name + '_cast'
+                shape = g.output[0].type.tensor_type.shape
+                cast_node = onnx.helper.make_node('Cast', [g.output[0].name], [new_output_name],
+                                                  to=onnx.onnx_pb.TensorProto.INT64)
+                new_output = [onnx.helper.make_tensor_value_info(
+                    new_output_name, onnx.onnx_pb.TensorProto.INT64, None)] + list(g.output)[1:]
+                if shape is not None:
+                    new_output[0].type.tensor_type.shape.CopyFrom(shape)
+                g.node.append(cast_node)
+                g.ClearField('output')
+                g.output.extend(new_output)
+
         return g

     def post_processing(self, **kwargs):

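The cast_token_id branch added above is ordinary ONNX graph surgery. A standalone sketch of the same idea, using only public onnx helpers; the function name is illustrative and not part of this change:

import onnx
from onnx import helper, TensorProto

def cast_first_output_to_int64(graph: onnx.GraphProto) -> onnx.GraphProto:
    # Mirror of the new logic: if the first output is not int64, append a Cast
    # node and re-declare that output with the cast name, int64 type and old shape.
    out0 = graph.output[0]
    if out0.type.tensor_type.elem_type == TensorProto.INT64:
        return graph  # already int64, nothing to do
    cast_name = out0.name + "_cast"
    graph.node.append(helper.make_node(
        "Cast", [out0.name], [cast_name], to=TensorProto.INT64))
    new_out = helper.make_tensor_value_info(cast_name, TensorProto.INT64, None)
    new_out.type.tensor_type.shape.CopyFrom(out0.type.tensor_type.shape)
    new_outputs = [new_out] + list(graph.output)[1:]
    del graph.output[:]
    graph.output.extend(new_outputs)
    return graph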
@@ -41,6 +41,8 @@ def gen_processing_models(processor: Union[str, object],
            has to be provided in the kwargs
     pre_kwargs: dict
         Keyword arguments for generating the pre-processing model
+        WITH_DEFAULT_INPUTS: bool, add default inputs to the graph, default is True
+        CAST_TOKEN_ID: bool, add a cast op to output token IDs to be int64 if needed, default is False
     post_kwargs: dict
         Keyword arguments for generating the post-processing model
     opset: int
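Because the docstring only states the flags' intent, here is a small sketch of how the effect can be checked statically on the returned pre-processing model; same assumptions and checkpoint as the sketch near the top of this page:

import onnx
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
pre_model, _ = gen_processing_models(
    tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True, "CAST_TOKEN_ID": True})

# With CAST_TOKEN_ID=True the first output is declared int64,
# either natively or through the appended Cast node.
elem_type = pre_model.graph.output[0].type.tensor_type.elem_type
assert elem_type == onnx.TensorProto.INT64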
@@ -54,7 +56,8 @@ def gen_processing_models(processor: Union[str, object],
         The pre- and post-processing ONNX models
     """
     if pre_kwargs is None and post_kwargs is None:
-        raise ValueError("Either pre_kwargs or post_kwargs should be provided. None means no processing")
+        raise ValueError(
+            "Either pre_kwargs or post_kwargs should be provided. None means no processing graph output.")
     if isinstance(processor, str):
         g_pre, g_post = (None, None)
         if pre_kwargs:

@@ -64,7 +67,8 @@ def gen_processing_models(processor: Union[str, object],
                 cls_name = processor
             else:
                 if processor not in _PRE_POST_PAIR:
-                    raise RuntimeError(f"Cannot locate the post processing operator name from {processor}")
+                    raise RuntimeError(
+                        f"Cannot locate the post processing operator name from {processor}")
                 cls_name = _PRE_POST_PAIR[processor]
             g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
         return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

@@ -72,15 +76,20 @@ def gen_processing_models(processor: Union[str, object],
     cls_name = type(processor).__name__
     if cls_name == "WhisperProcessor":
         if WhisperDataProcGraph is None:
-            raise ValueError("The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
+            raise ValueError(
+                "The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
         _converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
-        pre_m = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
-        post_m = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
+        pre_m = _converter.pre_processing(
+            **pre_kwargs) if pre_kwargs is not None else None
+        post_m = _converter.post_processing(
+            **post_kwargs) if post_kwargs is not None else None
         return pre_m, post_m
     elif HFTokenizerOnnxGraph.is_supported(processor):
         _converter = HFTokenizerOnnxGraph(processor)
-        pre_g = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
-        post_g = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
+        pre_g = _converter.pre_processing(
+            **pre_kwargs) if pre_kwargs is not None else None
+        post_g = _converter.post_processing(
+            **post_kwargs) if post_kwargs is not None else None
         return make_onnx_model(pre_g) if pre_g else None, \
             make_onnx_model(post_g) if post_g else None
     else:

@@ -112,11 +112,13 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
     all_initializers = initializers + top_level_initializers

     # Filter the initializers by checking if their names are in the list of used input tensors
-    used_initializers = [init for init in all_initializers if init.name in input_tensors]
+    used_initializers = [
+        init for init in all_initializers if init.name in input_tensors]

     # Update the subgraph's initializers
     del subgraph.initializer[:]
-    subgraph.initializer.extend([init for init in used_initializers if init in initializers])
+    subgraph.initializer.extend(
+        [init for init in used_initializers if init in initializers])

     # Recursively process subgraphs within this subgraph
     for node in nodes:

@@ -125,7 +127,8 @@ def remove_unused_initializers(subgraph, top_level_initializers=None):
             remove_unused_initializers(attr.g, top_level_initializers)
         elif attr.type == onnx.AttributeProto.GRAPHS:
             for subgraph in attr.graphs:
-                remove_unused_initializers(subgraph, top_level_initializers)
+                remove_unused_initializers(
+                    subgraph, top_level_initializers)


 def quick_merge(*models, connection_indices=None):

@@ -150,12 +153,14 @@ def quick_merge(*models, connection_indices=None):
     merged_graph = models[0].graph

     # Dictionary to store unique opsets
-    opset_imports = {opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}
+    opset_imports = {
+        opset.domain if opset.domain else "ai.onnx": opset for opset in models[0].opset_import}

     # Iterate over all other models and merge
     for model_idx, model in enumerate(models[1:], start=1):
         if connection_indices is None:
-            io_map = [(out.name, in_.name) for out, in_ in zip(models[model_idx - 1].graph.output, model.graph.input)]
+            io_map = [(out.name, in_.name) for out, in_ in zip(
+                models[model_idx - 1].graph.output, model.graph.input)]
         else:
             io_map = [(models[model_idx - 1].graph.output[out_idx].name, model.graph.input[in_idx].name)
                       for out_idx, in_idx in connection_indices[model_idx - 1]]

@@ -174,7 +179,8 @@ def quick_merge(*models, connection_indices=None):

     default_opset = opset_imports.pop("ai.onnx", None)
     merged_model = onnx.helper.make_model_gen_version(merged_graph,
-                                                      opset_imports=[default_opset],
+                                                      opset_imports=[
+                                                          default_opset],
                                                       producer_name='ONNX Model Merger')
     merged_model.opset_import.extend(opset_imports.values())
     return merged_model

@@ -28,6 +28,20 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)

+    def test_llama_tokenizer_id64(self):
+        # replace the official model name after the model is not gated anymore
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        text = "I was born in 92000, and this is falsé."
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True,
+                        "CAST_TOKEN_ID": True})[0])
+        actual_ids = ort_tok([text])[0]
+        self.assertEqual(actual_ids.dtype, np.int64)
+        np.testing.assert_array_equal(ids[0], actual_ids)
+
     def test_falcon_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b", use_fast=False)
         text = "why don't you teach me some German?"