Add Llama and Llama 2 tokenization support (#499)

This commit is contained in:
Wenbing Li 2023-07-26 10:22:00 -07:00 committed by GitHub
Parent 01d3905801
Commit b8bac85ecd
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 95 additions and 41 deletions
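For context, the new Llama path is exercised end to end roughly as follows. This is a minimal sketch mirroring the test_llama_tokenizer test added below, with the ungated hf-internal-testing checkpoint standing in for the gated official Llama model.

import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

# Load the SentencePiece-based Llama tokenizer from an ungated mirror.
hf_tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
expected = hf_tok.encode("I was born in 92000, and this is falsé.", return_tensors="np")

# gen_processing_models returns (pre-processing, post-processing) models; index 0
# is the tokenizer graph. WITH_DEFAULT_INPUTS bakes add_bos/add_eos/etc. into the
# graph as initializers, so only the text input remains on the interface.
onnx_tok = gen_processing_models(hf_tok, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]

ort_tok = OrtPyFunction.from_model(onnx_tok)
actual = ort_tok(["I was born in 92000, and this is falsé."])[0]
np.testing.assert_array_equal(expected[0], actual)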

View file

@@ -30,6 +30,10 @@ class CustomOp:
     def get_outputs(cls):
         return None
 
+    @classmethod
+    def input_default_values(cls):
+        return None
+
     @classmethod
     def serialize_attr(cls, attrs):
         """
@@ -312,6 +316,17 @@ class SentencepieceTokenizer(CustomOp):
             cls.io_def('reverse', onnx_proto.TensorProto.BOOL, [None])
         ]
 
+    # beyond Python 3.7, the order of the dict is guaranteed to be insertion order
+    @classmethod
+    def input_default_values(cls):
+        return {
+            'nbest_size': [0],
+            'alpha': [0],
+            'add_bos': [False],
+            'add_eos': [False],
+            'reverse': [False]
+        }
+
     @classmethod
     def get_outputs(cls):
         return [
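The new input_default_values() hook gives each optional SentencepieceTokenizer input a one-element default, keyed by the same names declared via io_def above. A small illustrative sketch of how these defaults can be looked up (the same calls the converter in the next file uses; the import path is the package-internal module that file imports from):

from onnxruntime_extensions._cuops import SingleOpGraph

# Resolve the op class by name, as HFTokenizerOnnxGraph.pre_processing does below.
op_class = SingleOpGraph.get_op_class("SentencepieceTokenizer")

# Dict of optional-input defaults; the base CustomOp returns None here,
# so ops without declared defaults can be detected and rejected by callers.
print(op_class.input_default_values())
# {'nbest_size': [0], 'alpha': [0], 'add_bos': [False],
#  'add_eos': [False], 'reverse': [False]}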

View file

@@ -8,7 +8,10 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 """
 import json
+import onnx
+import numpy as np
 from functools import partial
+from collections import namedtuple
 
 from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
@@ -23,7 +26,7 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
+                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
@@ -58,7 +61,7 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_clip_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_clip_tokenizer.bpe_ranks.items()}
+                         v_ in hf_clip_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
@@ -69,36 +72,48 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_roberta_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_roberta_tokenizer.bpe_ranks.items()}
+                         v_ in hf_roberta_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
         return attrs
 
-    def t5_tokenizer(self, **kwargs):
+    def spm_tokenizer(self, **kwargs):
         attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
         attrs.update(**kwargs)
         return attrs
 
-    def t5_decoder(self, **kwargs):
+    def spm_decoder(self, **kwargs):
         attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
         attrs.update(**kwargs)
         return attrs
 
 
+TokenOpParam = namedtuple("TokenOpParam",
+                          ["pre_op", "pre_attribute_cvt",
+                           "post_op", "post_attribute_cvt",
+                           "default_inputs"],
+                          defaults=(None, None, None, None, None))
+
+# fmt: off
 _PROCESSOR_DICT = {
-    "GPT2Tokenizer": ('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                      'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "ClipTokenizer": ('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
-                      'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "RobertaTokenizer": ("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                         None, None),
-    "T5Tokenizer": ("SentencepieceTokenizer", HFTokenizerConverter.t5_tokenizer,
-                    "SentencepieceDecoder", HFTokenizerConverter.t5_decoder),
+    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
+    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
+    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
+                                     None, None),
+    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+                                default_inputs={'add_eos': [True]}),
+    "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                   "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+                                   default_inputs={'add_bos': [True]}),
 }
+# fmt: on
 
 
 class HFTokenizerOnnxGraph:
     @staticmethod
     def extract_cls_name(processor):
         cls_name = processor if isinstance(processor, str) else type(processor).__name__
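Each TokenOpParam entry above names the pre/post custom ops, their attribute converters, and any per-tokenizer default-input overrides; the overrides are what distinguish Llama (prepend BOS) from T5 (append EOS). A quick illustration of the lookup, reaching into the private _PROCESSOR_DICT purely for demonstration:

from onnxruntime_extensions._hf_cvt import _PROCESSOR_DICT

# T5 appends EOS by default, Llama prepends BOS; entries without overrides
# fall back to the op-level defaults declared on SentencepieceTokenizer.
print(_PROCESSOR_DICT["T5Tokenizer"].default_inputs)     # {'add_eos': [True]}
print(_PROCESSOR_DICT["LlamaTokenizer"].default_inputs)  # {'add_bos': [True]}
print(_PROCESSOR_DICT["GPT2Tokenizer"].default_inputs)   # None
print(_PROCESSOR_DICT["LlamaTokenizer"].pre_op)          # 'SentencepieceTokenizer'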
@@ -117,13 +132,40 @@ class HFTokenizerOnnxGraph:
         self.cvt_obj = HFTokenizerConverter(processor)
 
     def pre_processing(self, **kwargs):
-        _cvt_op = self.cvt_quadruple[0]
-        _cvt_func = self.cvt_quadruple[1]
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+        _cvt_op = self.cvt_quadruple.pre_op
+        _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
-        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
+            n_inputs = len(default_inputs)
+            if self.cvt_quadruple.default_inputs is not None:
+                default_inputs.update(self.cvt_quadruple.default_inputs)
+                if len(default_inputs) != n_inputs:
+                    raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op))
+
+            new_initializers = []
+            for k, v in default_inputs.items():
+                input_value_info = next((i for i in g.input if i.name == k), None)
+                if input_value_info is None:
+                    raise ValueError("The input {} is not found in the graph".format(k))
+                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+                value = np.array(v, np_dtype)
+                new_initializers.append(onnx.numpy_helper.from_array(value, k))
+            g.initializer.extend(new_initializers)
+            new_inputs = [i for i in g.input if i.name not in default_inputs]
+            g.ClearField("input")
+            g.input.extend(new_inputs)
+
+        return g
 
     def post_processing(self, **kwargs):
-        _cvt_op = self.cvt_quadruple[2]
-        _cvt_func = self.cvt_quadruple[3]
+        _cvt_op = self.cvt_quadruple.post_op
+        _cvt_func = self.cvt_quadruple.post_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
         return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
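Net effect of the pre_processing change: with WITH_DEFAULT_INPUTS enabled, the optional SentencePiece inputs are dropped from the graph interface and re-added as initializers, so the exported tokenizer takes only the text input. A hedged inspection sketch, assuming gen_processing_models returns ONNX ModelProto objects (as the tests' use of OrtPyFunction.from_model suggests):

from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models

hf_tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
pre_model = gen_processing_models(hf_tok, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]

# Expect only the string input to remain on the interface ...
print([i.name for i in pre_model.graph.input])
# ... while nbest_size/alpha/add_bos/add_eos/reverse now live as initializers,
# with the Llama-specific add_bos=[True] override applied.
print([init.name for init in pre_model.graph.initializer])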

View file

@@ -2,39 +2,36 @@
 # Licensed under the MIT License.
 
 import sys
 import unittest
-import transformers as _hfts
 import numpy as np
 import onnxruntime as _ort
 from packaging import version
+from transformers import AutoTokenizer, WhisperProcessor
 from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
 
 
 @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
+    def test_llama_tokenizer(self):
+        # replace the official model name after the model is not gated anymore
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
+        actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
+        np.testing.assert_array_equal(ids[0], actual_ids)
+
     def test_t5_tokenizer(self):
-        tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+        tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
         ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
-        print(ids)
-
-        alpha = 0
-        nbest_size = 0
-        flags = 0
-        t5_default_inputs = (
-            np.array(
-                [nbest_size], dtype=np.int64),
-            np.array([alpha], dtype=np.float32),
-            np.array([flags & 1], dtype=np.bool_),
-            np.array([flags & 2], dtype=np.bool_),
-            np.array([flags & 4], dtype=np.bool_))
 
         ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
-        actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
-        np.testing.assert_array_equal(ids[0][:-1], actual_ids)
+        actual_ids = ort_tok(["best hotel in bay area."])[0]
+        np.testing.assert_array_equal(ids[0], actual_ids)
 
     def test_whisper_overall(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, post_m = gen_processing_models(processor,
                                               pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
                                               post_kwargs={})
@@ -51,7 +48,7 @@ class TestAutoTokenizer(unittest.TestCase):
         self.assertEqual(rel[0], "$%&")
 
     def test_whisper_audio_decoder(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
@@ -64,7 +61,7 @@ class TestAutoTokenizer(unittest.TestCase):
 
     @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_ort_stft_consistency(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
@@ -86,7 +83,7 @@ class TestAutoTokenizer(unittest.TestCase):
 
     @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_stft_norm_consistency(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})