update the bert end to end example with hftok (#236)
This commit is contained in:
Родитель
49548f843d
Коммит
da4784a2cc
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from ._base import ProcessingTracedModule, tensor_data_type as _dt
|
from ._base import ProcessingTracedModule, tensor_data_type as _dt
|
||||||
from ._torchext import create_op_function
|
from ._torchext import create_op_function
|
||||||
|
@ -17,7 +18,9 @@ def make_custom_op(ctx, op_type, input_names, output_names, container, operator_
|
||||||
def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
|
def bert_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
|
||||||
if 'hf_tok' in kwargs:
|
if 'hf_tok' in kwargs:
|
||||||
hf_bert_tokenizer = kwargs['hf_tok']
|
hf_bert_tokenizer = kwargs['hf_tok']
|
||||||
attrs = {'vocab_file': hf_bert_tokenizer.vocab}
|
ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
|
||||||
|
vocab = '\n'.join(ordered_vocab.keys())
|
||||||
|
attrs = dict(vocab_file=vocab)
|
||||||
elif 'vocab_file' in kwargs:
|
elif 'vocab_file' in kwargs:
|
||||||
vocab = None
|
vocab = None
|
||||||
vocab_file = kwargs['vocab_file']
|
vocab_file = kwargs['vocab_file']
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
import onnxruntime_extensions
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from onnxruntime_extensions import pnp, OrtPyFunction
|
from onnxruntime_extensions import pnp, OrtPyFunction
|
||||||
|
@ -10,13 +9,13 @@ from transformers.onnx import export, FeaturesManager
|
||||||
# get an onnx model by converting HuggingFace pretrained model
|
# get an onnx model by converting HuggingFace pretrained model
|
||||||
model_name = "bert-base-cased"
|
model_name = "bert-base-cased"
|
||||||
model_path = Path("onnx-model/bert-base-cased.onnx")
|
model_path = Path("onnx-model/bert-base-cased.onnx")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if not model_path.parent.exists():
|
if not model_path.parent.exists():
|
||||||
model_path.parent.mkdir(parents=True, exist_ok=True)
|
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
model = FeaturesManager.get_model_from_feature("default", model_name)
|
model = FeaturesManager.get_model_from_feature("default", model_name)
|
||||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
|
||||||
onnx_config = model_onnx_config(model.config)
|
onnx_config = model_onnx_config(model.config)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
export(tokenizer,
|
export(tokenizer,
|
||||||
model=model,
|
model=model,
|
||||||
config=onnx_config,
|
config=onnx_config,
|
||||||
|
@ -35,20 +34,20 @@ def mapping_token_output(_1, _2, _3):
|
||||||
|
|
||||||
|
|
||||||
test_sentence = ["this is a test sentence."]
|
test_sentence = ["this is a test sentence."]
|
||||||
ort_tok = pnp.PreHuggingFaceBert(
|
ort_tok = pnp.PreHuggingFaceBert(hf_tok=tokenizer)
|
||||||
vocab_file=onnxruntime_extensions.get_test_data_file(
|
|
||||||
'../test', 'data', 'bert_basic_cased_vocab.txt'))
|
|
||||||
onnx_model = onnx.load_model(str(model_path))
|
onnx_model = onnx.load_model(str(model_path))
|
||||||
|
|
||||||
|
|
||||||
|
augmented_model_name = 'temp_bert_tok_all.onnx'
|
||||||
# create the final onnx model which includes pre- and post- processing.
|
# create the final onnx model which includes pre- and post- processing.
|
||||||
augmented_model = pnp.export(pnp.SequentialProcessingModule(
|
augmented_model = pnp.export(pnp.SequentialProcessingModule(
|
||||||
ort_tok, mapping_token_output,
|
ort_tok, mapping_token_output,
|
||||||
onnx_model, post_processing_forward),
|
onnx_model, post_processing_forward),
|
||||||
test_sentence,
|
test_sentence,
|
||||||
opset_version=12,
|
opset_version=12,
|
||||||
output_path='bert_tok_all.onnx')
|
output_path=augmented_model_name)
|
||||||
|
|
||||||
# test the augmented onnx model with raw string input.
|
# test the augmented onnx model with raw string input.
|
||||||
model_func = OrtPyFunction.from_model('bert_tok_all.onnx')
|
model_func = OrtPyFunction.from_model(augmented_model_name)
|
||||||
result = model_func(test_sentence)
|
result = model_func(test_sentence)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче