Add Llama and Llama 2 tokenization support (#499)
This commit is contained in:
Parent: 01d3905801
Commit: b8bac85ecd
@@ -30,6 +30,10 @@ class CustomOp:
     def get_outputs(cls):
         return None
 
+    @classmethod
+    def input_default_values(cls):
+        return None
+
     @classmethod
     def serialize_attr(cls, attrs):
         """
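The `input_default_values` hook added to the `CustomOp` base class returns None so callers can tell "no defaults defined" apart from an empty mapping; concrete ops override it to advertise their optional inputs. A minimal sketch of the pattern (the `EchoOp` subclass is hypothetical, for illustration only; the real CustomOp in onnxruntime-extensions carries many more members):

    class CustomOp:
        @classmethod
        def input_default_values(cls):
            # Base hook: ops without default inputs return None.
            return None

    class EchoOp(CustomOp):
        # Hypothetical op overriding the hook to declare one default.
        @classmethod
        def input_default_values(cls):
            return {'flag': [False]}

    print(CustomOp.input_default_values())  # None
    print(EchoOp.input_default_values())    # {'flag': [False]}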
@@ -312,6 +316,17 @@ class SentencepieceTokenizer(CustomOp):
             cls.io_def('reverse', onnx_proto.TensorProto.BOOL, [None])
         ]
 
+    # beyond Python 3.7, the order of the dict is guaranteed to be insertion order
+    @classmethod
+    def input_default_values(cls):
+        return {
+            'nbest_size': [0],
+            'alpha': [0],
+            'add_bos': [False],
+            'add_eos': [False],
+            'reverse': [False]
+        }
+
     @classmethod
     def get_outputs(cls):
         return [
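The comment about Python 3.7 is load-bearing: `SentencepieceTokenizer.input_default_values` lists its keys in the same order as the op's declared optional inputs, and dict insertion order is guaranteed from Python 3.7 onward, so the name-to-position pairing survives iteration. A quick plain-Python check of that assumption:

    defaults = {
        'nbest_size': [0],
        'alpha': [0],
        'add_bos': [False],
        'add_eos': [False],
        'reverse': [False],
    }

    # The op's optional inputs, in declaration order (matching the io_def calls).
    declared = ['nbest_size', 'alpha', 'add_bos', 'add_eos', 'reverse']
    assert list(defaults) == declared  # holds on Python >= 3.7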
@@ -8,7 +8,10 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 """
 
 import json
+import onnx
+import numpy as np
 from functools import partial
+from collections import namedtuple
 
 from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
@@ -23,7 +26,7 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
+                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
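The GPT-2, CLIP, and RoBERTa converters all serialize a BPE tokenizer the same way: `bpe_ranks` maps each merge pair to its rank, so inverting the dict and walking ranks 0..N-1 reconstructs the merges list line by line, while the vocabulary is dumped as compact JSON. A standalone sketch with toy data standing in for a real HuggingFace tokenizer:

    import json

    # Toy stand-in for hf_gpt2_tokenizer.bpe_ranks: (token_a, token_b) -> rank.
    bpe_ranks = {('h', 'e'): 0, ('l', 'l'): 1, ('he', 'll'): 2}

    # Invert to rank -> pair, then emit one "a b" line per rank, lowest rank first.
    sorted_merges = {v_: k_ for k_, v_ in bpe_ranks.items()}
    merges = '\n'.join("{} {}".format(*sorted_merges[n_])
                       for n_ in range(len(sorted_merges)))
    print(merges)  # "h e", "l l", "he ll" -- one merge per line

    # The vocabulary attribute is compact JSON, as in the converter above.
    encoder = {'he': 0, 'll': 1, 'hell': 2}
    vocab_attr = json.dumps(encoder, separators=(',', ':'))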
@@ -58,7 +61,7 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_clip_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_clip_tokenizer.bpe_ranks.items()}
+                         v_ in hf_clip_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
@@ -69,36 +72,48 @@ class HFTokenizerConverter(CustomOpConverter):
         attrs = {'vocab': json.dumps(
             hf_roberta_tokenizer.encoder, separators=(',', ':'))}
         sorted_merges = {v_: k_ for k_,
-                         v_ in hf_roberta_tokenizer.bpe_ranks.items()}
+                         v_ in hf_roberta_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
         return attrs
 
-    def t5_tokenizer(self, **kwargs):
+    def spm_tokenizer(self, **kwargs):
         attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
         attrs.update(**kwargs)
         return attrs
 
-    def t5_decoder(self, **kwargs):
+    def spm_decoder(self, **kwargs):
         attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
         attrs.update(**kwargs)
         return attrs
 
 
+TokenOpParam = namedtuple("TokenOpParam",
+                          ["pre_op", "pre_attribute_cvt",
+                           "post_op", "post_attribute_cvt",
+                           "default_inputs"],
+                          defaults=(None, None, None, None, None))
+
+# fmt: off
 _PROCESSOR_DICT = {
-    "GPT2Tokenizer": ('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                      'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "ClipTokenizer": ('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
-                      'BpeDecoder', HFTokenizerConverter.bpe_decoder),
-    "RobertaTokenizer": ("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                         None, None),
-    "T5Tokenizer": ("SentencepieceTokenizer", HFTokenizerConverter.t5_tokenizer,
-                    "SentencepieceDecoder", HFTokenizerConverter.t5_decoder),
+    "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
+    "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder),
+    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
+                                     None, None),
+    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+                                default_inputs={'add_eos': [True]}),
+    "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
+                                   "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
+                                   default_inputs={'add_bos': [True]}),
 }
+# fmt: on
 
 
 class HFTokenizerOnnxGraph:
 
     @staticmethod
     def extract_cls_name(processor):
         cls_name = processor if isinstance(processor, str) else type(processor).__name__
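Switching `_PROCESSOR_DICT` from bare tuples to the `TokenOpParam` namedtuple buys two things: trailing fields can be omitted (the `defaults` argument backfills None), and call sites read `param.pre_op` instead of `param[0]`. A small sketch, with None placeholders where the real entries hold HFTokenizerConverter methods:

    from collections import namedtuple

    TokenOpParam = namedtuple("TokenOpParam",
                              ["pre_op", "pre_attribute_cvt",
                               "post_op", "post_attribute_cvt",
                               "default_inputs"],
                              defaults=(None, None, None, None, None))

    # Roberta-style entry: no post op, no default inputs, so it stays short.
    roberta = TokenOpParam("RobertaTokenizer", None)
    # Llama-style entry: default_inputs supplied by keyword.
    llama = TokenOpParam("SentencepieceTokenizer", None,
                         "SentencepieceDecoder", None,
                         default_inputs={'add_bos': [True]})

    assert roberta.post_op is None and roberta.default_inputs is None
    assert llama.default_inputs == {'add_bos': [True]}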
@@ -117,13 +132,40 @@ class HFTokenizerOnnxGraph:
         self.cvt_obj = HFTokenizerConverter(processor)
 
     def pre_processing(self, **kwargs):
-        _cvt_op = self.cvt_quadruple[0]
-        _cvt_func = self.cvt_quadruple[1]
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+        _cvt_op = self.cvt_quadruple.pre_op
+        _cvt_func = self.cvt_quadruple.pre_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
-        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                raise ValueError("The op {} doesn't define default inputs".format(_cvt_op))
+            n_inputs = len(default_inputs)
+            if self.cvt_quadruple.default_inputs is not None:
+                default_inputs.update(self.cvt_quadruple.default_inputs)
+                if len(default_inputs) != n_inputs:
+                    raise ValueError("Op: {} doesn't have the inputs from its TokenOpParam.".format(_cvt_op))
+
+            new_initializers = []
+
+            for k, v in default_inputs.items():
+                input_value_info = next((i for i in g.input if i.name == k), None)
+                if input_value_info is None:
+                    raise ValueError("The input {} is not found in the graph".format(k))
+
+                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type)
+                value = np.array(v, np_dtype)
+                new_initializers.append(onnx.numpy_helper.from_array(value, k))
+            g.initializer.extend(new_initializers)
+            new_inputs = [i for i in g.input if i.name not in default_inputs]
+            g.ClearField("input")
+            g.input.extend(new_inputs)
+        return g
 
     def post_processing(self, **kwargs):
-        _cvt_op = self.cvt_quadruple[2]
-        _cvt_func = self.cvt_quadruple[3]
+        _cvt_op = self.cvt_quadruple.post_op
+        _cvt_func = self.cvt_quadruple.post_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
         return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
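The core of the change is the default-input folding in `pre_processing`: every defaulted graph input is rewritten as an initializer of the matching ONNX element type and dropped from the input list, so the exported tokenizer graph only asks the caller for text. A self-contained sketch of that transformation on a toy graph (assumes onnx >= 1.13, which provides `tensor_dtype_to_np_dtype`):

    import numpy as np
    import onnx
    from onnx import helper, numpy_helper, TensorProto

    # Toy graph with a text input and one defaultable boolean input.
    g = helper.make_graph(
        [helper.make_node("Identity", ["add_bos"], ["out"])],
        "toy",
        [helper.make_tensor_value_info("text", TensorProto.STRING, [None]),
         helper.make_tensor_value_info("add_bos", TensorProto.BOOL, [None])],
        [helper.make_tensor_value_info("out", TensorProto.BOOL, [None])])

    default_inputs = {'add_bos': [True]}

    new_initializers = []
    for k, v in default_inputs.items():
        vi = next(i for i in g.input if i.name == k)
        np_dtype = helper.tensor_dtype_to_np_dtype(vi.type.tensor_type.elem_type)
        new_initializers.append(numpy_helper.from_array(np.array(v, np_dtype), k))

    # Bake the defaults in, then drop the corresponding graph inputs.
    g.initializer.extend(new_initializers)
    new_inputs = [i for i in g.input if i.name not in default_inputs]
    g.ClearField("input")
    g.input.extend(new_inputs)

    print([i.name for i in g.input])  # ['text'] -- add_bos is now baked in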
@@ -2,39 +2,36 @@
 # Licensed under the MIT License.
 import sys
 import unittest
-import transformers as _hfts
 
 import numpy as np
 import onnxruntime as _ort
 from packaging import version
+from transformers import AutoTokenizer, WhisperProcessor
 from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models
 
 
 @unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0")
 class TestAutoTokenizer(unittest.TestCase):
+    def test_llama_tokenizer(self):
+        # replace the official model name after the model is not gated anymore
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
+        actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0]
+        np.testing.assert_array_equal(ids[0], actual_ids)
+
     def test_t5_tokenizer(self):
-        tokenizer = _hfts.AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+        tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
         ids = tokenizer.encode("best hotel in bay area.", return_tensors="np")
         print(ids)
 
-        alpha = 0
-        nbest_size = 0
-        flags = 0
-
-        t5_default_inputs = (
-            np.array(
-                [nbest_size], dtype=np.int64),
-            np.array([alpha], dtype=np.float32),
-            np.array([flags & 1], dtype=np.bool_),
-            np.array([flags & 2], dtype=np.bool_),
-            np.array([flags & 4], dtype=np.bool_))
-
         ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0])
-        actual_ids = ort_tok(["best hotel in bay area."], *t5_default_inputs)[0]
-        np.testing.assert_array_equal(ids[0][:-1], actual_ids)
+        actual_ids = ort_tok(["best hotel in bay area."])[0]
+        np.testing.assert_array_equal(ids[0], actual_ids)
 
     def test_whisper_overall(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, post_m = gen_processing_models(processor,
                                               pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False},
                                               post_kwargs={})
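The new Llama test doubles as the intended usage: with WITH_DEFAULT_INPUTS the SentencePiece options, including the `add_bos=[True]` override from LlamaTokenizer's TokenOpParam, are baked into the exported model, so the ONNX tokenizer takes nothing but the text and matches HuggingFace token for token. In sketch form (mirroring the test above):

    from transformers import AutoTokenizer
    from onnxruntime_extensions import OrtPyFunction, gen_processing_models

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # gen_processing_models returns (pre, post); index 0 is the tokenizer model.
    onnx_model = gen_processing_models(
        tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]
    ort_tok = OrtPyFunction.from_model(onnx_model)
    ids = ort_tok(["I was born in 92000, and this is falsé."])[0]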
@@ -51,7 +48,7 @@ class TestAutoTokenizer(unittest.TestCase):
         self.assertEqual(rel[0], "$%&")
 
     def test_whisper_audio_decoder(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})
 
@@ -64,7 +61,7 @@ class TestAutoTokenizer(unittest.TestCase):
 
     @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_ort_stft_consistency(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})
 
@@ -86,7 +83,7 @@ class TestAutoTokenizer(unittest.TestCase):
 
     @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
    def test_stft_norm_consistency(self):
-        processor = _hfts.WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
                                          pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
 