update the bert end to end example with hftok (#236)
This commit is contained in:
Родитель
49548f843d
Коммит
da4784a2cc
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from ._base import ProcessingTracedModule, tensor_data_type as _dt
|
||||
from ._torchext import create_op_function
|
||||
|
@ -17,7 +18,9 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
|
|||
def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
|
||||
if 'hf_tok' in kwargs:
|
||||
hf_bert_tokenizer = kwargs['hf_tok']
|
||||
attrs = {'vocab_file': hf_bert_tokenizer.vocab}
|
||||
ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
|
||||
vocab = '\n'.join(ordered_vocab.keys())
|
||||
attrs = dict(vocab_file=vocab)
|
||||
elif 'vocab_file' in kwargs:
|
||||
vocab = None
|
||||
vocab_file = kwargs['vocab_file']
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import onnx
|
||||
import torch
|
||||
import onnxruntime_extensions
|
||||
|
||||
from pathlib import Path
|
||||
from onnxruntime_extensions import pnp, OrtPyFunction
|
||||
|
@ -10,13 +9,13 @@ from transformers.onnx import export, FeaturesManager
|
|||
# get an ONNX model by converting a HuggingFace pretrained model
|
||||
model_name = "bert-base-cased"
|
||||
model_path = Path("onnx-model/bert-base-cased.onnx")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if not model_path.exists():
|
||||
if not model_path.parent.exists():
|
||||
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
model = FeaturesManager.get_model_from_feature("default", model_name)
|
||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
|
||||
onnx_config = model_onnx_config(model.config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
export(tokenizer,
|
||||
model=model,
|
||||
config=onnx_config,
|
||||
|
@ -35,20 +34,20 @@ def mapping_token_output(_1, _2, _3):
|
|||
|
||||
|
||||
test_sentence = ["this is a test sentence."]
|
||||
ort_tok = pnp.PreHuggingFaceBert(
|
||||
vocab_file=onnxruntime_extensions.get_test_data_file(
|
||||
'../test', 'data', 'bert_basic_cased_vocab.txt'))
|
||||
ort_tok = pnp.PreHuggingFaceBert(hf_tok=tokenizer)
|
||||
onnx_model = onnx.load_model(str(model_path))
|
||||
|
||||
|
||||
augmented_model_name = 'temp_bert_tok_all.onnx'
|
||||
# create the final ONNX model, which includes pre- and post-processing.
|
||||
augmented_model = pnp.export(pnp.SequentialProcessingModule(
|
||||
ort_tok, mapping_token_output,
|
||||
onnx_model, post_processing_forward),
|
||||
test_sentence,
|
||||
opset_version=12,
|
||||
output_path='bert_tok_all.onnx')
|
||||
output_path=augmented_model_name)
|
||||
|
||||
# test the augmented onnx model with raw string input.
|
||||
model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
|
||||
model_func = OrtPyFunction.from_model(augmented_model_name)
|
||||
result = model_func(test_sentence)
|
||||
print(result)
|
||||
|
|
Загрузка…
Ссылка в новой задаче