Fix HF Fast Tokenizer conversion issue for AutoTokenizer implementation (#520)

* Fix GPT2 and Falcon tokenizer cvt for AutoTokenizer imp

* fix fast tokenizer issue

* small fix

* use slow tokenizer in test script

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
This commit is contained in:
Sayan Shaw 2023-08-11 13:17:56 -07:00 committed by GitHub
Parent cd416e2ab4
Commit 9ba649e134
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 7 deletions

View file

@ -23,6 +23,10 @@ class HFTokenizerConverter(CustomOpConverter):
def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
if ("Fast" in str(self.tokenizer)):
raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
attrs = {'vocab': json.dumps(
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
@ -77,6 +81,10 @@ class HFTokenizerConverter(CustomOpConverter):
def clip_tokenizer(self, **kwargs):
hf_clip_tokenizer = self.tokenizer
if ("Fast" in str(self.tokenizer)):
raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')
attrs = {'vocab': json.dumps(
hf_clip_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
@ -88,6 +96,10 @@ class HFTokenizerConverter(CustomOpConverter):
def roberta_tokenizer(self, **kwargs):
hf_roberta_tokenizer = self.tokenizer
if ("Fast" in str(self.tokenizer)):
raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')
attrs = {'vocab': json.dumps(
hf_roberta_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
@ -121,7 +133,7 @@ _PROCESSOR_DICT = {
"DistilBertTokenizer":
TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
"GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
"GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
"ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
@ -132,9 +144,7 @@ _PROCESSOR_DICT = {
default_inputs={'add_eos': [True]}),
"LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
"SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
default_inputs={'add_bos': [True]}),
"FalconTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.bpe_tokenizer,
'BpeDecoder', HFTokenizerConverter.bpe_decoder, None)
default_inputs={'add_bos': [True]})
}
# @formatter:on
@ -198,4 +208,4 @@ class HFTokenizerOnnxGraph:
_cvt_op = self.cvt_quadruple.post_op
_cvt_func = self.cvt_quadruple.post_attribute_cvt
cvt = partial(_cvt_func, self.cvt_obj)
return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)

View file

@ -30,7 +30,7 @@ class TestAutoTokenizer(unittest.TestCase):
def test_falcon_tokenizer(self):
# replace the official model name after the model is not gated anymore
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b", use_fast=False)
text = "why don't you teach me some German?"
ids = tokenizer.encode(text, return_tensors="np")
@ -38,7 +38,7 @@ class TestAutoTokenizer(unittest.TestCase):
tokenizer,
pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids[0], actual_ids)
np.testing.assert_array_equal(ids, actual_ids)
def test_t5_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)