diff --git a/onnxruntime_extensions/pnp/_nlp.py b/onnxruntime_extensions/pnp/_nlp.py
index f8f89fa0..f8447910 100644
--- a/onnxruntime_extensions/pnp/_nlp.py
+++ b/onnxruntime_extensions/pnp/_nlp.py
@@ -16,16 +16,21 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
         outputs=((_dt.INT64, []), (_dt.INT64, []), (_dt.INT64, [])))
 def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
     if 'hf_tok' in kwargs:
-        # TODO: need bert-tokenizer support JSON format
         hf_bert_tokenizer = kwargs['hf_tok']
-        attrs = {'vocab_file': json.dumps(hf_bert_tokenizer.vocab, separators=(',', ':'))}
+        attrs = {'vocab_file': hf_bert_tokenizer.vocab}
     elif 'vocab_file' in kwargs:
-        attrs = dict(vocab_file=kwargs['vocab_file'])
+        vocab = None
+        vocab_file = kwargs['vocab_file']
+        with open(vocab_file, "r", encoding='utf-8') as vf:
+            lines = vf.readlines()
+            vocab = '\n'.join(lines)
+        if vocab is None:
+            raise RuntimeError("Cannot load vocabulary file {}!".format(vocab_file))
+        attrs = dict(vocab_file=vocab)
     else:
         raise RuntimeError("Need hf_tok/vocab_file parameter to build the tokenizer")
     if 'strip_accents' in kwargs:
-        strip_accents = kwargs['strip_accents']
-        attrs['strip_accents'] = strip_accents
+        attrs['strip_accents'] = kwargs['strip_accents']
 
     return make_custom_op(ctx, 'BertTokenizer', input_names, output_names,
                           container, operator_name=operator_name, **attrs)
@@ -67,18 +72,12 @@ class PreHuggingFaceBert(ProcessingTracedModule):
     def __init__(self, hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1):
         super(PreHuggingFaceBert, self).__init__()
         if hf_tok is None:
-            _vocab = None
-            with open(vocab_file, "r", encoding='utf-8') as vf:
-                lines = vf.readlines()
-                _vocab = '\n'.join(lines)
-            if _vocab is None:
-                raise RuntimeError("Cannot load vocabulary file {}!".format(vocab_file))
             self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize,
-                                                         vocab_file=_vocab,
+                                                         vocab_file=vocab_file,
                                                          do_lower_case=do_lower_case,
                                                          strip_accents=strip_accents)
         else:
-            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize, hf_tok=self.hf_tok)
+            self.onnx_bert_tokenize = create_op_function('BertTokenizer', bert_tokenize, hf_tok=hf_tok)
 
     def forward(self, text):
         return self.onnx_bert_tokenize(text)
@@ -96,7 +95,7 @@ class PreHuggingFaceGPT2(ProcessingTracedModule):
                                                          merges=_get_file_content(merges_file),
                                                          padding_length=padding_length)
         else:
-            self.onnx_gpt2_tokenize = create_op_function('GPT2Tokenizer', gpt2_tokenize, hf_tok=self.hf_tok)
+            self.onnx_gpt2_tokenize = create_op_function('GPT2Tokenizer', gpt2_tokenize, hf_tok=hf_tok)
 
     def forward(self, text):
         return self.onnx_gpt2_tokenize(text)