Close gap in vocab for AutoTokenizer support for GPT4Tokenizer (#567)
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Parent: b7e35a1a34
Commit: bcde705eec
@@ -9,6 +9,8 @@ _hf_cvt.py: HuggingFace Tokenizer/Processor Converter
 import json
 import onnx
+import uuid
+import numpy as np
 from numpy import array as nparray
 from functools import partial
 from collections import namedtuple, OrderedDict

@@ -23,12 +25,31 @@ class HFTokenizerConverter(CustomOpConverter):
     def bpe_tokenizer(self, **kwargs):
         hf_gpt2_tokenizer = self.tokenizer
+        attrs = None
+
         if type(self.tokenizer).__name__.endswith('Fast'):
             raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
-
-        attrs = {'vocab': json.dumps(
-            hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
+        elif self.tokenizer.name_or_path.endswith('gpt-4'):
+            # Fill the vocab gaps for GPT4Tokenizer so the token ids form a contiguous range
+            vocab_dict = hf_gpt2_tokenizer.encoder
+            partial_values = list(vocab_dict.values())
+
+            max_vocab = partial_values[-1]
+            all_values = np.arange(max_vocab + 1)
+
+            missing_values = set(all_values) - set(partial_values)
+
+            for v in missing_values:
+                vocab_dict[str(uuid.uuid4())] = int(v)
+
+            vocab_dict = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))
+
+            attrs = {'vocab': json.dumps(
+                vocab_dict, separators=(',', ':'))}
+        else:
+            attrs = {'vocab': json.dumps(
+                hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
+
         sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
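The elif branch above is the core of the change. Here is a minimal standalone sketch of the same gap-filling idea, using a toy vocabulary invented for illustration (GPT4Tokenizer's real encoder has similar holes, just far more entries):

import json
import uuid
import numpy as np

# Toy vocab with gaps at ids 2 and 3.
vocab_dict = {"a": 0, "b": 1, "d": 4}
partial_values = list(vocab_dict.values())

# Assumes the dict is ordered by id, so the last value is the largest id.
max_vocab = partial_values[-1]
all_values = np.arange(max_vocab + 1)

# Ids in 0..max_vocab that no token occupies.
missing_values = set(all_values) - set(partial_values)

# Occupy each unused id with a random placeholder key.
for v in missing_values:
    vocab_dict[str(uuid.uuid4())] = int(v)

vocab_dict = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))
print(json.dumps(vocab_dict, separators=(',', ':')))
# -> {"a":0,"b":1,"<uuid>":2,"<uuid>":3,"d":4}

Random UUID keys are a cheap way to fill the unused ids: they can never be produced by BPE merges, so encoding and decoding of real text are unaffected.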
@@ -130,7 +151,7 @@ _PROCESSOR_DICT = {
     "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
                                         'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
     "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
     "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
                                      'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
     "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
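For context on the table above: each _PROCESSOR_DICT entry maps a HuggingFace tokenizer class name to the ONNX custom op and attribute converter used for pre-processing, plus the op and converter used for decoding. A rough sketch of the tuple's shape, with field names assumed for illustration rather than copied from the source:

from collections import namedtuple

# Field names here are illustrative assumptions, not the repo's actual ones.
TokenOpParam = namedtuple(
    "TokenOpParam",
    ["pre_op", "pre_attribute_cvt", "post_op", "post_attribute_cvt", "default_inputs"])

# GPT-2-style tokenizers encode via the GPT2Tokenizer custom op and decode via
# BpeDecoder; the real table passes bound converter methods where this passes None.
entry = TokenOpParam("GPT2Tokenizer", None, "BpeDecoder", None, None)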
@@ -3,7 +3,7 @@
 import unittest

 import numpy as np
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GPT2Tokenizer
 from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference, util


@@ -68,6 +68,17 @@ class TestAutoTokenizer(unittest.TestCase):
             pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids, actual_ids)

+    def test_gpt2_tokenizer(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
+        text = "Deep learning has come a long way, no?"
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(ids, actual_ids)
+
     def test_xmlroberta_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
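The new test doubles as usage documentation. A standalone sketch of the same round trip, assuming transformers and onnxruntime-extensions are installed (all names come from the diff above):

import numpy as np
from transformers import GPT2Tokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

# Load the slow GPT-4 tokenizer (the fast variant is rejected by the converter).
tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
text = "Deep learning has come a long way, no?"
ids = tokenizer.encode(text, return_tensors="np")

# Convert the tokenizer to an ONNX pre-processing model and run it with ORT.
pre_model = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]
ort_tok = OrtPyFunction.from_model(pre_model)
actual_ids = ort_tok([text])[0]

# With the vocab gaps filled, the ONNX op reproduces HuggingFace's ids exactly.
np.testing.assert_array_equal(ids, actual_ids)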