Update the BERT end-to-end example with hf_tok (#236)

Wenbing Li 2022-06-01 10:41:42 -07:00, committed by GitHub
Parent 49548f843d
Commit da4784a2cc
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 10 additions and 8 deletions

View file

@@ -1,4 +1,5 @@
 import json
+from collections import OrderedDict
 from ._base import ProcessingTracedModule, tensor_data_type as _dt
 from ._torchext import create_op_function
@@ -17,7 +18,9 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
 def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
     if 'hf_tok' in kwargs:
         hf_bert_tokenizer = kwargs['hf_tok']
-        attrs = {'vocab_file': hf_bert_tokenizer.vocab}
+        ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        vocab = '\n'.join(ordered_vocab.keys())
+        attrs = dict(vocab_file=vocab)
     elif 'vocab_file' in kwargs:
         vocab = None
         vocab_file = kwargs['vocab_file']

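For context, a minimal standalone sketch of what the new hf_tok branch does. The toy_vocab dict is a made-up stand-in for a HuggingFace tokenizer's vocab attribute, which maps token strings to integer ids; the change sorts that mapping by id and flattens it into the newline-separated string expected by the vocab_file attribute.

from collections import OrderedDict

# Stand-in for hf_bert_tokenizer.vocab: {token: id} (illustrative values only).
toy_vocab = {'hello': 2, '[PAD]': 0, 'world': 3, '[UNK]': 1}

# Sort by token id and join the tokens with newlines, mirroring the diff above.
ordered_vocab = OrderedDict(sorted(toy_vocab.items(), key=lambda item: int(item[1])))
vocab = '\n'.join(ordered_vocab.keys())
attrs = dict(vocab_file=vocab)

print(attrs['vocab_file'])  # prints: [PAD], [UNK], hello, world on separate lines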
View file

@@ -1,6 +1,5 @@
 import onnx
 import torch
-import onnxruntime_extensions
 from pathlib import Path
 from onnxruntime_extensions import pnp, OrtPyFunction
@@ -10,13 +9,13 @@ from transformers.onnx import export, FeaturesManager
 # get an onnx model by converting HuggingFace pretrained model
 model_name = "bert-base-cased"
 model_path = Path("onnx-model/bert-base-cased.onnx")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 if not model_path.exists():
     if not model_path.parent.exists():
         model_path.parent.mkdir(parents=True, exist_ok=True)
     model = FeaturesManager.get_model_from_feature("default", model_name)
     model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
     onnx_config = model_onnx_config(model.config)
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
     export(tokenizer,
            model=model,
            config=onnx_config,
@@ -35,20 +34,20 @@ def mapping_token_output(_1, _2, _3):
 test_sentence = ["this is a test sentence."]
-ort_tok = pnp.PreHuggingFaceBert(
-    vocab_file=onnxruntime_extensions.get_test_data_file(
-        '../test', 'data', 'bert_basic_cased_vocab.txt'))
+ort_tok = pnp.PreHuggingFaceBert(hf_tok=tokenizer)
 onnx_model = onnx.load_model(str(model_path))
+augmented_model_name = 'temp_bert_tok_all.onnx'

 # create the final onnx model which includes pre- and post- processing.
 augmented_model = pnp.export(pnp.SequentialProcessingModule(
     ort_tok, mapping_token_output,
     onnx_model, post_processing_forward),
     test_sentence,
     opset_version=12,
-    output_path='bert_tok_all.onnx')
+    output_path=augmented_model_name)

 # test the augmented onnx model with raw string input.
-model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
+model_func = OrtPyFunction.from_model(augmented_model_name)
 result = model_func(test_sentence)
 print(result)
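A short usage sketch of the updated pattern, assuming transformers and onnxruntime_extensions are installed; it only covers the tokenizer wiring shown in this diff, not the full export and inference flow.

from transformers import AutoTokenizer
from onnxruntime_extensions import pnp

# Build the pre-processing step directly from the HuggingFace tokenizer,
# replacing the old construction from a local vocab_file in the test data.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
ort_tok = pnp.PreHuggingFaceBert(hf_tok=tokenizer)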