ort-customops/tutorials/gpt2bs.py


####################################################################################
#
# !!! This script is replaced by the latest onnxruntime contrib op solution, which is
# https://github.com/microsoft/onnxruntime/blob/ad9d2e2e891714e0911ccc3fa8b70f42025b4d56/onnxruntime/python/tools/transformers/convert_beam_search.py
#
###################################################################################
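# This tutorial script builds a single ONNX model that runs GPT-2 beam-search
# generation end to end:
#   1. convert_gpt2() exports the one-step beam-search GPT-2 model with the
#      onnxruntime transformers helper;
#   2. inference_and_dump_full_model() traces the Python beam-search loop with
#      onnxruntime_extensions.onnxprocess and saves the optimized full model,
#      optionally with the GPT2Tokenizer custom op embedded;
#   3. verify_bsfull_model() reloads the saved model and checks the generation.
# Usage: python gpt2bs.py [--disable-tokenizer] [--output <file>]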
import os
import onnx
import numpy
import argparse
import onnxruntime as _ort
import onnxruntime_extensions as _ortex
from transformers import AutoConfig
from distutils.version import StrictVersion
if StrictVersion(_ort.__version__) < StrictVersion('1.8.1'):
    raise RuntimeError('Full GPT-2 model is only available on onnxruntime 1.8.1 and higher versions.')
model_name_or_path = "gpt2"
device = "cpu"
default_beam_width = 4
default_batch_size = 1
onnx_model_path = "gpt2_one_step_search.onnx"
gpt2_full_model_path = "gpt2_full.onnx"
# Create a cache directory to store the pretrained model.
cache_dir = os.path.expanduser('~/.cache/huggingface/')
if not os.path.exists(cache_dir):
    cache_dir = os.path.join(".", "cache_models")
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
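

# Convert the numpy outputs of the HuggingFace tokenizer into the dtypes the
# traced model expects: int64 input_ids and float32 attention_mask.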
def _extract_endict(tokenizer_endict):
    _1, _2 = [tokenizer_endict.get(ky_) for ky_ in ('input_ids', 'attention_mask')]
    return _1.astype(numpy.int64), _2.astype(numpy.float32)
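

# Load the HuggingFace GPT2Tokenizer (left padding, EOS used as PAD). When
# enable_tokenizer is True, also build a GPT2Tokenizer custom-op ONNX model and
# return it wrapped as a callable, so tokenization can be traced into the graph.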
def get_tokenizer(model_name_or_path, enable_tokenizer, cache_dir):
    from transformers import GPT2Tokenizer  # noqa
    from onnxruntime_extensions.onnxprocess import build_customop_model, pyfunc_from_model
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    if enable_tokenizer:
        gpt2_encoder_model_path = './gpt2_tok.onnx'
        build_customop_model('GPT2Tokenizer', gpt2_encoder_model_path, model=tokenizer)
        return tokenizer, pyfunc_from_model(gpt2_encoder_model_path)
    else:
        return tokenizer, None
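

# Export the one-step GPT-2 beam-search model (GPT2LMHeadModel_BeamSearchStep)
# to ONNX with the onnxruntime transformers helper.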
def convert_gpt2():
    from onnxruntime.transformers.gpt2_beamsearch_helper import Gpt2BeamSearchHelper, GPT2LMHeadModel_BeamSearchStep
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    model = GPT2LMHeadModel_BeamSearchStep.from_pretrained(model_name_or_path,
                                                           config=config,
                                                           batch_size=default_batch_size,
                                                           beam_size=default_beam_width,
                                                           cache_dir=cache_dir)
    model.eval().to(device)
    Gpt2BeamSearchHelper.export_onnx(model, device, onnx_model_path)
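

# Trace the whole beam-search generation loop around the one-step model and
# save the optimized full model to gpt2_full_model_path. If func_tokenizer is
# None, the traced graph takes input_ids/attention_mask; otherwise it takes raw
# text and runs the tokenizer custom op first.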
def inference_and_dump_full_model(tokenizer, func_tokenizer, input_text, num_tokens_to_produce=30):
    from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
    # a hot fix for the dynamic axes of the converted model
    gpt2_core = onnx.load_model(onnx_model_path)
    for _vi in gpt2_core.graph.output:
        if _vi.name == 'last_state':
            _vi.type.tensor_type.shape.dim[1].dim_param = 'seq_len'
    func_one_step = pyfunc_from_model(gpt2_core)
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    num_attention_heads = config.n_head
    hidden_size = config.n_embd
    num_layer = config.n_layer
    if func_tokenizer is None:
        input_ids, attention_mask = _extract_endict(tokenizer(input_text, padding=True, return_tensors='np'))
        with trace_for_onnx(input_ids, attention_mask,
                            num_tokens_to_produce,
                            names=["input_ids", "attention_mask", "out_token_num"], target_opset=12) as tc_sess:
            input_ids, attention_mask, num_tokens = tc_sess.get_inputs()
            input_ids.symbolic_shape = ['batch_size', 'seq_len']
            attention_mask.symbolic_shape = ['batch_size', 'seq_len']
            full_model = _beam_search(tokenizer, func_one_step,
                                      num_attention_heads, hidden_size, num_layer,
                                      tc_sess, num_tokens, input_ids, attention_mask)
    else:
        with trace_for_onnx(input_text, num_tokens_to_produce,
                            names=func_tokenizer.input_names, target_opset=12) as tc_sess:
            inputs, num_tokens = tc_sess.get_inputs()
            input_ids, attention_mask = func_tokenizer(inputs, padding=True)
            full_model = _beam_search(tokenizer, func_one_step,
                                      num_attention_heads, hidden_size, num_layer,
                                      tc_sess, num_tokens, input_ids, attention_mask)
    _ortex.optimize_model(full_model, gpt2_full_model_path)
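

# The traced beam-search loop: each iteration feeds the loop-carried states to
# the one-step model, rebuilds position_ids and attention_mask, collects the
# new past key/values, and stops when no sentence is unfinished. Returns the
# ONNX model saved from the traced session.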
def _beam_search(tokenizer, func_one_step,
                 num_attention_heads, hidden_size, num_layer, tc_sess, num_tokens, input_ids, attention_mask):
    from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
    if attention_mask.dtype is not torch.float32:
        attention_mask = attention_mask.type(torch.float)
    position_ids = (attention_mask.long().cumsum(-1) - 1)
    batch_size = default_batch_size
    past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
    empty_past = []
    for _ in range(num_layer):
        empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))
    beam_select_idx = torch.zeros([1, batch_size]).long()
    input_log_probs = torch.zeros([batch_size, 1])
    input_unfinished_sents = torch.ones([batch_size, 1], dtype=torch.bool)
    prev_step_scores = torch.zeros([batch_size, 1])
    beam_size = default_beam_width
    prev_step_results = input_ids.clone().detach().to(device)
    cfg = torch.control_flow()
    for states in cfg.loop(num_tokens, torch.tensor(True), input_ids, position_ids,
                           attention_mask, beam_select_idx, input_log_probs,
                           input_unfinished_sents, prev_step_results, prev_step_scores, *empty_past):
        step = states[0]
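        # states[0] is the step counter; states[1:] follow the cfg.loop order:
        # input_ids, position_ids, attention_mask, beam_select_idx,
        # input_log_probs, input_unfinished_sents, prev_step_results,
        # prev_step_scores, *past. Pin their symbolic shapes before tracing
        # the one-step model call.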
        states[1].symbolic_shape = ['batch_size', 'seq_len']
        states[2].symbolic_shape = ['batch_size', 'seq_len']
        states[3].symbolic_shape = ['batch_size', 'all_seq_len']
        states[4].symbolic_shape = [1, 'batch_size']
        # prev_step_results
        states[7].symbolic_shape = ['batch_size', 'total_seq_len']
        for st_ in states[-num_layer:]:
            st_.symbolic_shape = [2, 'batch_size', num_attention_heads,
                                  'past_seq_len', hidden_size // num_attention_heads]
        prev_attention_mask = states[3]
        outputs = func_one_step(*states[1:])
        last_state = outputs[0].clone().detach().cpu()
        input_ids = last_state.reshape([batch_size * beam_size, -1]).to(device)
        input_unfinished_sents_id = -3
        prev_step_results = outputs[-2].clone().detach().to(device)
        # position_ids = (torch.tensor([context_length + step - 1
        #                               ]).unsqueeze(0).repeat(batch_size * beam_size, 1).to(device))
        position_ids = torch.zeros([batch_size * beam_size, 1], dtype=torch.int64) + attention_mask.size()[-1]
        factor = (~step.type(torch.bool)).type(torch.int64)
        prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
        attention_mask = torch.cat(
            [
                prev_attention_mask,
                torch.ones([batch_size * beam_size, 1], dtype=torch.float),
            ],
            1,
        ).to(device)
        beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
        input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
        input_unfinished_sents = outputs[input_unfinished_sents_id].clone().detach().to(device)
        prev_step_scores = outputs[-1].clone().detach().to(device)
        past = []
        for i in range(num_layer):
            past_i = outputs[i + 1].clone().detach()
            past.append(past_i.to(device))
        any_unfinished = input_unfinished_sents.any()
        input_ids.symbolic_shape = ['total_batch_size', 'seq_len']
        position_ids.symbolic_shape = ['total_batch_size', 'seq_len']
        attention_mask.symbolic_shape = ['total_batch_size', 'all_seq_len']
        prev_step_results.symbolic_shape = ['total_batch_size', 'step_seq_len']
        for st_ in past:
            st_.symbolic_shape = [2, 'total_batch_size', num_attention_heads,
                                  'all_seq_len', hidden_size // num_attention_heads]
        cfg.flow_output(any_unfinished, input_ids,
                        position_ids, attention_mask, beam_select_idx,
                        input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, *past)
    result_id = 6
    all_token_ids = cfg.finalize()[result_id]
    mdl = tc_sess.save_as_onnx(None, all_token_ids)
    print(tokenizer.decode(all_token_ids.t[0], skip_special_tokens=True))
    return mdl
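

# Reload the saved full model with PyOrtFunction and run it once, timing the
# generation and printing the decoded text.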
def verify_bsfull_model(input_text, tokenizer, enable_tokenizer):
    import time
    from onnxruntime_extensions import PyOrtFunction
    gpt2_all = PyOrtFunction.from_model(gpt2_full_model_path)
    gpt2_all._ensure_ort_session()
    if enable_tokenizer:
        start_time = time.perf_counter()
        outputs = gpt2_all(input_text, 30)
    else:
        input_ids, attention_mask = _extract_endict(tokenizer(input_text, padding=True, return_tensors='np'))
        start_time = time.perf_counter()
        outputs = gpt2_all(input_ids, attention_mask, 30)
    print("total time: {}".format(time.perf_counter() - start_time))
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
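

# End-to-end driver: build (or reuse) the one-step model, trace and dump the
# full beam-search model, then verify it.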
def main(enable_tokenizer):
    tokenizer, func_tokenizer = get_tokenizer(model_name_or_path, enable_tokenizer, cache_dir)
    input_text = ['best hotel in bay area.']
    if not os.path.exists(onnx_model_path):
        convert_gpt2()
    inference_and_dump_full_model(tokenizer, func_tokenizer, input_text)
    verify_bsfull_model(input_text, tokenizer, enable_tokenizer)
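

# --disable-tokenizer builds the full model without the tokenizer custom op;
# --output overrides the default path of the saved full model.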
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-tokenizer", '-d', help="No tokenizer operator for the full model",
                        action="store_true")
    parser.add_argument("--output", '-o', help="The output file name")
    args = parser.parse_args()
    if args.output is not None:
        gpt2_full_model_path = args.output
    main(not args.disable_tokenizer)