the full beam search post processing for GPT-2 model (#94)
* the full beam search post processing for GPT-2 model * minor fixes
This commit is contained in:
Родитель
06c902253f
Коммит
352b2003bc
|
@ -23,10 +23,10 @@ output, *_ = gpt2_core(input_ids)
|
|||
next_id = numpy.argmax(output[:, :, -1, :], axis=-1)
|
||||
print(input_text[0] + decode(next_id).item())
|
||||
```
|
||||
This is a simplified version of GPT-2 inference for the demonstration only, The comprehensive solution on the GPT-2 model and its deviants are under development, and here is the [link](tutorials/gpt2_e2e.py) to the experimental.
|
||||
This is a simplified version of GPT-2 inference for the demonstration only, The comprehensive solution on the GPT-2 model and its deviants are under development, and here is the [link](tutorials/gpt2bs.py) to the experimental.
|
||||
|
||||
## Android/iOS
|
||||
The previous processing python code can be translated into all-in-one model to be run in Android/iOS mobile platform, without any Python runtime and the 3rd-party dependencies requirement. Here is the [tutorial](tutorials/ort_mobile.py)
|
||||
The previous processing python code can be translated into all-in-one model to be run in Android/iOS mobile platform, without any Python runtime and the 3rd-party dependencies requirement. Here is the [tutorial](tutorials/gpt2bs.py)
|
||||
|
||||
## CustomOp Conversion
|
||||
The mainstream ONNX converters support the custom op generation if there is the operation from the original framework cannot be interpreted as ONNX standard operators. Check the following two examples on how to do this.
|
||||
|
|
|
@ -89,6 +89,7 @@ class EagerOp:
|
|||
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||
idx += 1
|
||||
|
||||
# feed.update(kwargs)
|
||||
return feed
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
|
|
@ -1342,7 +1342,7 @@ class _ONNXOperatorAPI:
|
|||
|
||||
if repeats is None or (not isinstance(repeats, str) and all(repeat_count == 1 for repeat_count in repeats)):
|
||||
container.add_node('Identity', input_name, output_name, name=name)
|
||||
return
|
||||
return output_name
|
||||
|
||||
if container.target_opset < 6:
|
||||
intermediate_input_name = input_name
|
||||
|
|
|
@ -291,10 +291,10 @@ class ONNXTraceSession:
|
|||
if vi_.name in input_names:
|
||||
input_names[vi_.name] = vi_
|
||||
|
||||
inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.t.size())
|
||||
inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
|
||||
if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
|
||||
outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
|
||||
so.t.size()) for so in ts_outputs]
|
||||
so.get_shape()) for so in ts_outputs]
|
||||
|
||||
graph = helper.make_graph(nodes, graph_name, inputs,
|
||||
outputs, container.initializers)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import builtins
|
||||
import functools
|
||||
import numpy as np
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
@ -18,6 +19,7 @@ class _EagerTensor:
|
|||
name = name[0]
|
||||
self.name = '' if name is None else name
|
||||
self.raw_data = raw_data
|
||||
self.symbolic_shape = []
|
||||
|
||||
def __repr__(self):
|
||||
if self.raw_data is not None:
|
||||
|
@ -102,6 +104,9 @@ class _EagerTensor:
|
|||
def item(self):
|
||||
return self.numpy().item()
|
||||
|
||||
def get_shape(self):
|
||||
return self.t.size() if len(self.symbolic_shape) == 0 else self.symbolic_shape
|
||||
|
||||
def _to_binary_tensor_args(self, other):
|
||||
# convert self, other to [self, other], but if either is a number, convert that to a constant
|
||||
x, y = self, other
|
||||
|
@ -294,7 +299,7 @@ class _EagerTensor:
|
|||
|
||||
@classmethod
|
||||
def ox_args(cls, tensors, output_names=None):
|
||||
input_names = [ts_.name for ts_ in tensors]
|
||||
input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors]
|
||||
return cls.ox_name_args(input_names, output_names)
|
||||
|
||||
def my_args(self):
|
||||
|
@ -338,6 +343,11 @@ class _EagerTensor:
|
|||
s = _ox.identity(*self.my_args())
|
||||
return self.create_and_verify(y, s[0])
|
||||
|
||||
def cpu(self):
|
||||
y = self._t.cpu()
|
||||
s = _ox.identity(*self.my_args())
|
||||
return self.create_and_verify(y, s[0])
|
||||
|
||||
def detach(self):
|
||||
y = self._t.detach()
|
||||
s = _ox.identity(*self.my_args())
|
||||
|
@ -366,13 +376,10 @@ class _EagerTensor:
|
|||
return self.create_and_verify(y, s[0])
|
||||
|
||||
|
||||
def _create_ox_sequence(*size, init_value=None, onnx_type=None):
|
||||
def _create_ox_sequence(*size):
|
||||
container = _EagerTensor.get_container()
|
||||
con_x = []
|
||||
if onnx_type is None:
|
||||
onnx_type = onnx_proto.TensorProto.FLOAT
|
||||
ts_val = _ox.make_tensor(onnx_type, [1], [init_value])
|
||||
if any(isinstance(n_, _EagerTensor) for n_ in size):
|
||||
if builtins.any(isinstance(n_, _EagerTensor) for n_ in size):
|
||||
for x in size:
|
||||
if isinstance(x, _EagerTensor):
|
||||
x_h = _ox.unsqueeze(*_EagerTensor.ox_args([x]))[0]
|
||||
|
@ -380,43 +387,71 @@ def _create_ox_sequence(*size, init_value=None, onnx_type=None):
|
|||
x_c = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [x])
|
||||
x_h = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=x_c)[0]
|
||||
con_x.append(x_h)
|
||||
allnames = _ox.concat(con_x, [_ox.get_unique_tensor_name('concat')], container, None)
|
||||
s = _ox.constant_of_shape(allnames, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
|
||||
return _ox.concat(con_x, [_ox.get_unique_tensor_name('concat')], container, None)
|
||||
else:
|
||||
ts_size = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
|
||||
shp_c = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=ts_size)
|
||||
s = _ox.constant_of_shape(shp_c, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
|
||||
return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=ts_size)
|
||||
|
||||
|
||||
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
|
||||
if onnx_type is None:
|
||||
onnx_type = onnx_proto.TensorProto.FLOAT
|
||||
names = _create_ox_sequence(*size)
|
||||
ts_val = _ox.make_tensor(onnx_type, [1], [init_value])
|
||||
|
||||
container = _EagerTensor.get_container()
|
||||
s = _ox.constant_of_shape(names, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
|
||||
return s[0]
|
||||
|
||||
|
||||
def empty(*size: _int, memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
|
||||
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
|
||||
dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
|
||||
requires_grad: _bool = False) -> _EagerTensor: # noqa
|
||||
|
||||
if len(size) == 1 and isinstance(size[0], list):
|
||||
size = size[0]
|
||||
n_size = _EagerTensor.normalize_seq(size)
|
||||
y = torch.empty(*n_size, memory_format=memory_format, out=out,
|
||||
dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
|
||||
s = _create_ox_sequence(*size, init_value=0., onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
s = _create_ox_sequence_constant(*size, init_value=0., onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
return _EagerTensor.from_torch(y, s)
|
||||
|
||||
|
||||
def zeros(*size: _int, out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
|
||||
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
|
||||
device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor: # noqa
|
||||
|
||||
if len(size) == 1 and isinstance(size[0], list):
|
||||
size = size[0]
|
||||
n_size = _EagerTensor.normalize_seq(size)
|
||||
y = torch.zeros(*n_size, out=out, dtype=dtype,
|
||||
layout=layout, device=device, requires_grad=requires_grad)
|
||||
s = _create_ox_sequence(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
s = _create_ox_sequence_constant(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
return _EagerTensor.from_torch(y, s)
|
||||
|
||||
|
||||
def ones(*size: _int, out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
|
||||
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
|
||||
device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor: # noqa
|
||||
|
||||
if len(size) == 1 and isinstance(size[0], list):
|
||||
size = size[0]
|
||||
n_size = _EagerTensor.normalize_seq(size)
|
||||
y = torch.ones(*n_size, out=out, dtype=dtype,
|
||||
layout=layout, device=device, requires_grad=requires_grad)
|
||||
s = _create_ox_sequence(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
s = _create_ox_sequence_constant(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
|
||||
return _EagerTensor.from_torch(y, s)
|
||||
|
||||
|
||||
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor: # noqa
|
||||
|
||||
if len(repeats) == 1 and isinstance(repeats[0], list):
|
||||
repeats = repeats[0]
|
||||
n_size = _EagerTensor.normalize_seq(repeats)
|
||||
y = input_ts.t.repeat(*n_size)
|
||||
seq = _create_ox_sequence(*repeats)
|
||||
s = _ox.tile(*input_ts.my_args(), repeats=seq[0])
|
||||
return _EagerTensor.from_torch(y, s[0])
|
||||
|
||||
|
||||
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor: # noqa
|
||||
y = torch.argmax(input_ts.value, dim, keepdim)
|
||||
s = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
|
||||
|
@ -442,7 +477,18 @@ def all(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTenso
|
|||
container = _EagerTensor.get_container()
|
||||
y = torch.all(input_ts.value)
|
||||
s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
|
||||
s_redm = _ox.reducemin(s_casted, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[0])
|
||||
s_redm = _ox.reducemin(s_casted, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[-1])
|
||||
s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
|
||||
container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
|
||||
s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
|
||||
return input_ts.create_and_verify(y, s[0])
|
||||
|
||||
|
||||
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTensor: # noqa
|
||||
container = _EagerTensor.get_container()
|
||||
y = torch.any(input_ts.value)
|
||||
s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
|
||||
s_redm = _ox.reducesum(s_casted, [_ox.get_unique_tensor_name('reducesum')], container, None, axes=[-1])
|
||||
s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
|
||||
container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
|
||||
s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
|
||||
|
@ -489,7 +535,7 @@ class _ControlFlowContext:
|
|||
self.sub_graph = None
|
||||
|
||||
def flow_output(self, cond, *outputs):
|
||||
assert len(outputs) > len(self.loop_states), "The loop body doesn't return enough objects"
|
||||
assert len(outputs) >= len(self.loop_states), "The loop body doesn't return enough objects"
|
||||
if self.sub_graph is None:
|
||||
trc = _EagerTensor.get_trace_session()
|
||||
self.sub_graph = trc.build_graph(trc.container,
|
||||
|
@ -567,7 +613,10 @@ def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
|
|||
_EagerTensor._all_ops = {'argmax': argmax,
|
||||
'softmax': softmax,
|
||||
'reshape': reshape,
|
||||
'transpose': transpose}
|
||||
'transpose': transpose,
|
||||
'repeat': repeat,
|
||||
'any': any,
|
||||
'all': all}
|
||||
|
||||
tensor = _EagerTensor.mytensor
|
||||
tensor_from_onnx = _EagerTensor.from_onnx
|
||||
|
|
|
@ -115,7 +115,8 @@ if not os.path.exists(gpt2_core_model_path) or \
|
|||
output_ms = inference_and_dump_full_model(input_text)
|
||||
|
||||
# 3. Inference on the all-in-one model
|
||||
full_model = eager_op.EagerOp.from_model(gpt2_full_model_path)
|
||||
from onnxruntime_extensions import PyOrtFunction
|
||||
full_model = PyOrtFunction.from_model(gpt2_full_model_path)
|
||||
output_text = full_model(input_text, num_tokens_to_produce)
|
||||
|
||||
# 4. Test the result
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
import os
|
||||
from transformers import AutoConfig
|
||||
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model
|
||||
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
|
||||
|
||||
# Create a cache directory to store pretrained model.
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
if not os.path.exists(cache_dir):
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
|
||||
model_name_or_path = "gpt2"
|
||||
device = "cpu"
|
||||
beam_width = 4
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
num_attention_heads = config.n_head
|
||||
hidden_size = config.n_embd
|
||||
num_layer = config.n_layer
|
||||
gpt2_full_model_path = "./gpt2_full.onnx"
|
||||
|
||||
# this model was generated by this script
|
||||
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2-OneStepSearch_OnnxRuntime_CPU.ipynb
|
||||
onnx_model_path = "gpt2_one_step_search.onnx"
|
||||
func_one_step = pyfunc_from_model(onnx_model_path)
|
||||
|
||||
|
||||
def get_tokenizer(model_name_or_path, cache_dir):
|
||||
from transformers import GPT2Tokenizer # noqa
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
gpt2_encoder_model_path = './gpt2_tok.onnx'
|
||||
build_customop_model('GPT2Tokenizer', gpt2_encoder_model_path, model=tokenizer)
|
||||
return tokenizer, pyfunc_from_model(gpt2_encoder_model_path)
|
||||
|
||||
|
||||
def inference_and_dump_full_model(tokenizer, func_tokenizer, input_text, num_tokens_to_produce = 30):
|
||||
|
||||
with trace_for_onnx(input_text, num_tokens_to_produce, names=func_tokenizer.input_names) as tc_sess:
|
||||
inputs, num_tokens = tc_sess.get_inputs()
|
||||
input_ids, attention_mask = func_tokenizer(inputs, padding=True, padding_side='left')
|
||||
attention_mask = attention_mask.type(torch.float)
|
||||
position_ids = (attention_mask.long().cumsum(-1) - 1)
|
||||
# position_ids.masked_fill_(position_ids < 0, 0)
|
||||
|
||||
# Empty Past State for generating first word
|
||||
# batch_size = input_ids.size()[0]
|
||||
batch_size = 1
|
||||
past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
|
||||
empty_past = []
|
||||
for _ in range(num_layer):
|
||||
empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))
|
||||
|
||||
beam_select_idx = torch.zeros([1, batch_size]).long()
|
||||
input_log_probs = torch.zeros([batch_size, 1])
|
||||
input_unfinished_sents = torch.ones([batch_size, 1], dtype=torch.bool)
|
||||
prev_step_scores = torch.zeros([batch_size, 1])
|
||||
beam_size = beam_width
|
||||
prev_step_results = input_ids.clone().detach().to(device)
|
||||
|
||||
cfg = torch.control_flow()
|
||||
for states in cfg.loop(num_tokens, torch.tensor(True), input_ids, position_ids,
|
||||
attention_mask, beam_select_idx, input_log_probs,
|
||||
input_unfinished_sents, prev_step_results, prev_step_scores, *empty_past):
|
||||
step = states[0]
|
||||
states[1].symbolic_shape = ['batch_size', 'seq_len']
|
||||
states[2].symbolic_shape = ['batch_size', 'seq_len']
|
||||
states[3].symbolic_shape = ['batch_size', 'all_seq_len']
|
||||
states[4].symbolic_shape = [1, 'batch_size']
|
||||
|
||||
# prev_step_results
|
||||
states[7].symbolic_shape = ['batch_size', 'total_seq_len']
|
||||
|
||||
for st_ in states[-num_layer:]:
|
||||
st_.symbolic_shape = [2, 'batch_size', num_attention_heads, 'past_seq_len', hidden_size // num_attention_heads]
|
||||
|
||||
prev_attention_mask = states[3]
|
||||
outputs = func_one_step(*states[1:])
|
||||
last_state = outputs[0].clone().detach().cpu()
|
||||
input_ids = last_state.reshape([batch_size * beam_size, -1]).to(device)
|
||||
|
||||
input_unfinished_sents_id = -3
|
||||
prev_step_results = outputs[-2].clone().detach().to(device)
|
||||
# position_ids = (torch.tensor([context_length + step - 1
|
||||
# ]).unsqueeze(0).repeat(batch_size * beam_size, 1).to(device))
|
||||
position_ids = torch.zeros([batch_size * beam_size, 1], dtype=torch.int64) + attention_mask.size()[-1]
|
||||
factor = (~step.type(torch.bool)).type(torch.int64)
|
||||
prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
prev_attention_mask,
|
||||
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
|
||||
],
|
||||
1,
|
||||
).to(device)
|
||||
|
||||
beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
|
||||
input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
|
||||
input_unfinished_sents = outputs[input_unfinished_sents_id].clone().detach().to(device)
|
||||
prev_step_scores = outputs[-1].clone().detach().to(device)
|
||||
|
||||
past = []
|
||||
for i in range(num_layer):
|
||||
past_i = outputs[i + 1].clone().detach()
|
||||
past.append(past_i.to(device))
|
||||
|
||||
|
||||
any_unfinished = input_unfinished_sents.any()
|
||||
input_ids.symbolic_shape = ['total_batch_size', 'seq_len']
|
||||
position_ids.symbolic_shape = ['total_batch_size', 'seq_len']
|
||||
attention_mask.symbolic_shape = ['total_batch_size', 'all_seq_len']
|
||||
prev_step_results.symbolic_shape = ['total_batch_size', 'step_seq_len']
|
||||
for st_ in past:
|
||||
st_.symbolic_shape = [2, 'total_batch_size', num_attention_heads, 'all_seq_len', hidden_size // num_attention_heads]
|
||||
cfg.flow_output(any_unfinished, input_ids,
|
||||
position_ids, attention_mask, beam_select_idx,
|
||||
input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, *past)
|
||||
|
||||
result_id = 6
|
||||
all_token_ids = cfg.finalize()[result_id]
|
||||
tc_sess.save_as_onnx(gpt2_full_model_path, all_token_ids)
|
||||
|
||||
print(tokenizer.decode(all_token_ids.t[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
def verify_bsfull_model(input_text):
|
||||
from onnxruntime_extensions import PyOrtFunction
|
||||
gpt2_all = PyOrtFunction.from_model(gpt2_full_model_path)
|
||||
outputs = gpt2_all(input_text, 30)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer, func_tokenizer = get_tokenizer(model_name_or_path, cache_dir)
|
||||
input_text = ['best hotel in bay area.']
|
||||
inference_and_dump_full_model(tokenizer, func_tokenizer, input_text)
|
||||
verify_bsfull_model(input_text)
|
Загрузка…
Ссылка в новой задаче