Add more HF tokenizer supports in gen_processing_models (#531)
Parent: 29c6d66c02
Commit: 396044310e
@@ -24,13 +24,12 @@ class HFTokenizerConverter(CustomOpConverter):
     def bpe_tokenizer(self, **kwargs):
         hf_gpt2_tokenizer = self.tokenizer
 
-        if ("Fast" in str(self.tokenizer)):
+        if type(self.tokenizer).__name__.endswith('Fast'):
             raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
 
         attrs = {'vocab': json.dumps(
             hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
-        sorted_merges = {v_: k_ for k_,
-                         v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
+        sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
         attrs['merges'] = '\n'.join("{} {}".format(
             *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
         attrs.update(**kwargs)
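The `merges` attribute built above is a flat string: `bpe_ranks` (pair → rank) is inverted to rank → pair, and the pairs are joined one per line in rank order. A toy illustration with made-up merge data (the pairs below are hypothetical, not from a real vocabulary):

```python
# hypothetical bpe_ranks: each byte-pair maps to its merge priority
bpe_ranks = {('h', 'e'): 0, ('l', 'l'): 1, ('he', 'll'): 2}

sorted_merges = {v_: k_ for k_, v_ in bpe_ranks.items()}   # rank -> pair
merges = '\n'.join("{} {}".format(*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
print(merges)   # "h e", "l l", "he ll" -- one merge rule per line, highest priority first
```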
@@ -58,13 +57,9 @@ class HFTokenizerConverter(CustomOpConverter):
     def bpe_decoder(self, **kwargs):
         decoder = self.tokenizer.decoder
         id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
-        # with open("id_vocab.txt", "w", encoding="utf-8") as f:
-        #     f.write(id_vocab)
         byte_decoder = self.tokenizer.byte_decoder
         str_byte_decoder = "\n".join(["{}\t{}".format(
             ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
-        # with open("byte_decoder.txt", "w", encoding="utf-8") as f:
-        #     f.write(str_byte_decoder)
         all_special_ids = self.tokenizer.all_special_ids
         added_tokens = self.tokenizer.added_tokens_decoder
         str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
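The decoder attributes use the same flat serialization: `id_vocab` is the vocabulary joined by newlines in id order, and `str_byte_decoder` holds one `ord(char)<TAB>byte` pair per line. A toy illustration (hypothetical data; in GPT-2's byte-level scheme the character 'Ġ' stands in for the space byte):

```python
decoder = {0: 'hello', 1: 'world'}
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])   # "hello\nworld"

byte_decoder = {'!': 33, 'Ġ': 32}   # stand-in character -> original byte value
str_byte_decoder = "\n".join(["{}\t{}".format(
    ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
print(str_byte_decoder)   # "33<TAB>33" and "288<TAB>32" (288 is the code point of 'Ġ')
```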
@@ -82,7 +77,7 @@ class HFTokenizerConverter(CustomOpConverter):
     def clip_tokenizer(self, **kwargs):
         hf_clip_tokenizer = self.tokenizer
 
-        if ("Fast" in str(self.tokenizer)):
+        if type(self.tokenizer).__name__.endswith('Fast'):
             raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')
 
         attrs = {'vocab': json.dumps(
@@ -97,7 +92,7 @@ class HFTokenizerConverter(CustomOpConverter):
     def roberta_tokenizer(self, **kwargs):
         hf_roberta_tokenizer = self.tokenizer
 
-        if ("Fast" in str(self.tokenizer)):
+        if type(self.tokenizer).__name__.endswith('Fast'):
             raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')
 
         attrs = {'vocab': json.dumps(
@@ -126,25 +121,38 @@ TokenOpParam = namedtuple("TokenOpParam",
                            "default_inputs"],
                           defaults=(None, None, None, None, None))
 
 # Some tokenizers can be added by this table
 # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1252
 # @formatter:off
 _PROCESSOR_DICT = {
-    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
-    "DistilBertTokenizer":
-        TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                     'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
-    "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
-    "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
-    "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer,
-                                     None, None, None),
-    "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
-                                default_inputs={'add_eos': [True]}),
-    "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer,
-                                   "SentencepieceDecoder", HFTokenizerConverter.spm_decoder,
-                                   default_inputs={'add_bos': [True]})
+    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "RobertaTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "BartTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "LayoutLMv3Tokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "LongformerTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "LEDTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "MvpTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+    "T5Tokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
+                                'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
+                                default_inputs={'add_eos': [True]}),
+    "LlamaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
+                                   'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
+                                   default_inputs={'add_bos': [True]})
 }
 # @formatter:on
 
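Each `TokenOpParam` row pairs a pre-processing custom op and its attribute converter with the matching post-processing (decoder) op and converter, plus optional default inputs. A rough sketch of how the table is presumably consulted; only the `post_op` and `post_attribute_cvt` field names are visible in this diff, `_PROCESSOR_DICT` is module-internal, and the lookup by class name is an assumption:

```python
# hypothetical lookup sketch -- assumes the table is keyed by the slow tokenizer's class name
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=False)
param = _PROCESSOR_DICT[type(hf_tokenizer).__name__]   # the "RobertaTokenizer" row
print(param.post_op, param.post_attribute_cvt)         # BpeDecoder, HFTokenizerConverter.bpe_decoder
# default_inputs such as {'add_eos': [True]} become default inputs of the generated graph
```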
@@ -208,4 +216,4 @@ class HFTokenizerOnnxGraph:
         _cvt_op = self.cvt_quadruple.post_op
         _cvt_func = self.cvt_quadruple.post_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
-        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
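The table stores unbound `HFTokenizerConverter` methods, so `partial(_cvt_func, self.cvt_obj)` binds the converter instance before `SingleOpGraph.build_graph` calls it to produce the op attributes. A minimal sketch of that dispatch, assuming module-internal names and an `HFTokenizerConverter(tokenizer)` constructor that is not shown in this diff:

```python
from functools import partial
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
cvt_obj = HFTokenizerConverter(hf_tokenizer)               # constructor signature is assumed
cvt = partial(HFTokenizerConverter.bpe_decoder, cvt_obj)   # same as calling cvt_obj.bpe_decoder
graph = SingleOpGraph.build_graph('BpeDecoder', cvt=cvt)   # post-processing (decoder) graph
```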

@@ -28,7 +28,7 @@ KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKerne
   std::stringstream vocabu_stream(vocab);
   std::stringstream merges_stream(merges);
   bbpe_tokenizer_ = std::make_shared<VocabData>();
-  bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
+  bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|startoftext|>\n<|endoftext|>");
 }
 
 std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length, bool compute_offset_mapping,

@@ -28,7 +28,7 @@ KernelRobertaBpeTokenizer::KernelRobertaBpeTokenizer(const OrtApi& api, const Or
   std::stringstream vocabu_stream(vocab);
   std::stringstream merges_stream(merges);
   bbpe_tokenizer_ = std::make_shared<VocabData>();
-  bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
+  bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<unk>", "<s>\n</s>\n<pad>\n<mask>");
 }
 
 std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t max_length,
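The last argument to `VocabData::Load` appears to be a newline-separated list of special tokens and the third argument the unknown token, so CLIP now registers `<|startoftext|>` alongside `<|endoftext|>`, and RoBERTa registers `<s>`, `</s>`, `<pad>`, `<mask>` with `<unk>` as the unknown token. Those strings match the default special tokens of the slow HF tokenizers, which can be checked from Python (the checkpoint names are illustrative):

```python
from transformers import CLIPTokenizer, RobertaTokenizer

clip = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
print(clip.bos_token, clip.eos_token)   # <|startoftext|> <|endoftext|>

rob = RobertaTokenizer.from_pretrained("roberta-base")
print(rob.unk_token, rob.bos_token, rob.eos_token, rob.pad_token, rob.mask_token)
# <unk> <s> </s> <pad> <mask>
```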

@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 
 #include "ocos.h"
 

@@ -48,6 +48,17 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids[0], actual_ids)
 
+    def test_roberta_base(self):
+        tokenizer = AutoTokenizer.from_pretrained("SamLowe/roberta-base-go_emotions", use_fast=False)
+        text = "Agree. Keep trying, then if your rejected every time. I'm sorry your done."
+        ids = tokenizer.encode(text, return_tensors="np")
+        m_tok, m_detok = gen_processing_models(tokenizer, pre_kwargs={}, post_kwargs={})
+
+        actual_ids = OrtPyFunction(m_tok)([text])[0]
+        np.testing.assert_array_equal(ids, actual_ids)
+
+        self.assertEqual(OrtPyFunction(m_detok)(ids)[0], tokenizer.decode(ids[0]))
+
 
 if __name__ == '__main__':
     unittest.main()
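With the expanded table, the same round trip should in principle work for the other newly mapped classes. A hypothetical spot check for `CodeGenTokenizer` (the checkpoint name and sample text are assumptions, not part of this change):

```python
import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, OrtPyFunction

tok = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono", use_fast=False)
m_tok, m_detok = gen_processing_models(tok, pre_kwargs={}, post_kwargs={})

text = "def hello_world():"
ids = tok.encode(text, return_tensors="np")
np.testing.assert_array_equal(ids, OrtPyFunction(m_tok)([text])[0])   # tokenizer matches HF
assert OrtPyFunction(m_detok)(ids)[0] == tok.decode(ids[0])           # decoder round-trips
```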

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
 
 import argparse
 import pathlib