the full beam search post processing for GPT-2 model (#94)

* the full beam search post processing for GPT-2 model

* minor fixes
Wenbing Li 2021-05-24 23:05:34 -07:00 committed by GitHub
Parent 06c902253f
Commit 352b2003bc
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 214 additions and 25 deletions

View file

@@ -23,10 +23,10 @@ output, *_ = gpt2_core(input_ids)
next_id = numpy.argmax(output[:, :, -1, :], axis=-1)
print(input_text[0] + decode(next_id).item())
```
This is a simplified version of GPT-2 inference for demonstration purposes only. A comprehensive solution for the GPT-2 model and its variants is under development; here is the [link](tutorials/gpt2_e2e.py) to the experimental code.
This is a simplified version of GPT-2 inference for demonstration purposes only. A comprehensive solution for the GPT-2 model and its variants is under development; here is the [link](tutorials/gpt2bs.py) to the experimental code.
## Android/iOS
The preceding Python processing code can be converted into an all-in-one model that runs on the Android/iOS mobile platforms, with no Python runtime or third-party dependencies required. Here is the [tutorial](tutorials/ort_mobile.py)
The preceding Python processing code can be converted into an all-in-one model that runs on the Android/iOS mobile platforms, with no Python runtime or third-party dependencies required. Here is the [tutorial](tutorials/gpt2bs.py)
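As a minimal sketch of driving such an all-in-one model from Python (assuming the `gpt2_full.onnx` file produced by the beam-search tutorial; the mobile deployment itself goes through the ONNX Runtime mobile packages):
```python
from transformers import GPT2Tokenizer
from onnxruntime_extensions import PyOrtFunction

# Load the all-in-one model: tokenization, GPT-2 inference and beam search in one graph.
full_model = PyOrtFunction.from_model('gpt2_full.onnx')
# Inputs are the raw text batch and the number of tokens to produce.
output_ids = full_model(['best hotel in bay area.'], 30)
# The model outputs token ids; decode them with the matching HuggingFace tokenizer.
print(GPT2Tokenizer.from_pretrained('gpt2').decode(output_ids[0], skip_special_tokens=True))
```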
## CustomOp Conversion
Mainstream ONNX converters support custom op generation when an operation from the original framework cannot be expressed with standard ONNX operators. Check the following two examples on how to do this.

View file

@@ -89,6 +89,7 @@ class EagerOp:
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
idx += 1
# feed.update(kwargs)
return feed
def __call__(self, *args, **kwargs):

View file

@@ -1342,7 +1342,7 @@ class _ONNXOperatorAPI:
if repeats is None or (not isinstance(repeats, str) and all(repeat_count == 1 for repeat_count in repeats)):
container.add_node('Identity', input_name, output_name, name=name)
return
return output_name
if container.target_opset < 6:
intermediate_input_name = input_name

View file

@@ -291,10 +291,10 @@ class ONNXTraceSession:
if vi_.name in input_names:
input_names[vi_.name] = vi_
inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.t.size())
inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
so.t.size()) for so in ts_outputs]
so.get_shape()) for so in ts_outputs]
graph = helper.make_graph(nodes, graph_name, inputs,
outputs, container.initializers)

View file

@@ -1,4 +1,5 @@
import torch
import builtins
import functools
import numpy as np
from onnx import onnx_pb as onnx_proto
@@ -18,6 +19,7 @@ class _EagerTensor:
name = name[0]
self.name = '' if name is None else name
self.raw_data = raw_data
self.symbolic_shape = []
def __repr__(self):
if self.raw_data is not None:
@@ -102,6 +104,9 @@ class _EagerTensor:
def item(self):
return self.numpy().item()
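# get_shape prefers the user-assigned symbolic shape over the concrete torch
# size, so dynamic dimensions survive into the exported ONNX graph.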
def get_shape(self):
return self.t.size() if len(self.symbolic_shape) == 0 else self.symbolic_shape
def _to_binary_tensor_args(self, other):
# convert self, other to [self, other], but if either is a number, convert that to a constant
x, y = self, other
@@ -294,7 +299,7 @@ class _EagerTensor:
@classmethod
def ox_args(cls, tensors, output_names=None):
input_names = [ts_.name for ts_ in tensors]
input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors]
return cls.ox_name_args(input_names, output_names)
def my_args(self):
@@ -338,6 +343,11 @@ class _EagerTensor:
s = _ox.identity(*self.my_args())
return self.create_and_verify(y, s[0])
def cpu(self):
y = self._t.cpu()
s = _ox.identity(*self.my_args())
return self.create_and_verify(y, s[0])
def detach(self):
y = self._t.detach()
s = _ox.identity(*self.my_args())
@@ -366,13 +376,10 @@ class _EagerTensor:
return self.create_and_verify(y, s[0])
def _create_ox_sequence(*size, init_value=None, onnx_type=None):
def _create_ox_sequence(*size):
container = _EagerTensor.get_container()
con_x = []
if onnx_type is None:
onnx_type = onnx_proto.TensorProto.FLOAT
ts_val = _ox.make_tensor(onnx_type, [1], [init_value])
if any(isinstance(n_, _EagerTensor) for n_ in size):
if builtins.any(isinstance(n_, _EagerTensor) for n_ in size):
for x in size:
if isinstance(x, _EagerTensor):
x_h = _ox.unsqueeze(*_EagerTensor.ox_args([x]))[0]
@@ -380,43 +387,71 @@ def _create_ox_sequence(*size, init_value=None, onnx_type=None):
x_c = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [x])
x_h = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=x_c)[0]
con_x.append(x_h)
allnames = _ox.concat(con_x, [_ox.get_unique_tensor_name('concat')], container, None)
s = _ox.constant_of_shape(allnames, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
return _ox.concat(con_x, [_ox.get_unique_tensor_name('concat')], container, None)
else:
ts_size = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
shp_c = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=ts_size)
s = _ox.constant_of_shape(shp_c, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=ts_size)
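# The ConstantOfShape fill is split out into _create_ox_sequence_constant below,
# so that repeat() can reuse the raw shape sequence without a filled tensor.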
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
if onnx_type is None:
onnx_type = onnx_proto.TensorProto.FLOAT
names = _create_ox_sequence(*size)
ts_val = _ox.make_tensor(onnx_type, [1], [init_value])
container = _EagerTensor.get_container()
s = _ox.constant_of_shape(names, [_ox.get_unique_tensor_name('cos')], container, None, value=ts_val)
return s[0]
def empty(*size: _int, memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
requires_grad: _bool = False) -> _EagerTensor: # noqa
if len(size) == 1 and isinstance(size[0], list):
size = size[0]
n_size = _EagerTensor.normalize_seq(size)
y = torch.empty(*n_size, memory_format=memory_format, out=out,
dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
s = _create_ox_sequence(*size, init_value=0., onnx_type=_EagerTensor.to_onnx_type(y.dtype))
s = _create_ox_sequence_constant(*size, init_value=0., onnx_type=_EagerTensor.to_onnx_type(y.dtype))
return _EagerTensor.from_torch(y, s)
def zeros(*size: _int, out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor: # noqa
if len(size) == 1 and isinstance(size[0], list):
size = size[0]
n_size = _EagerTensor.normalize_seq(size)
y = torch.zeros(*n_size, out=out, dtype=dtype,
layout=layout, device=device, requires_grad=requires_grad)
s = _create_ox_sequence(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
s = _create_ox_sequence_constant(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
return _EagerTensor.from_torch(y, s)
def ones(*size: _int, out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor: # noqa
if len(size) == 1 and isinstance(size[0], list):
size = size[0]
n_size = _EagerTensor.normalize_seq(size)
y = torch.ones(*n_size, out=out, dtype=dtype,
layout=layout, device=device, requires_grad=requires_grad)
s = _create_ox_sequence(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
s = _create_ox_sequence_constant(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
return _EagerTensor.from_torch(y, s)
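# repeat() traces torch.Tensor.repeat to an ONNX Tile node; the repeat counts go
# through _create_ox_sequence, so they may be dynamic _EagerTensor values.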
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor: # noqa
if len(repeats) == 1 and isinstance(repeats[0], list):
repeats = repeats[0]
n_size = _EagerTensor.normalize_seq(repeats)
y = input_ts.t.repeat(*n_size)
seq = _create_ox_sequence(*repeats)
s = _ox.tile(*input_ts.my_args(), repeats=seq[0])
return _EagerTensor.from_torch(y, s[0])
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor: # noqa
y = torch.argmax(input_ts.value, dim, keepdim)
s = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
@@ -442,7 +477,18 @@ def all(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTensor
container = _EagerTensor.get_container()
y = torch.all(input_ts.value)
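# ONNX emulation of all(): cast to INT64, ReduceMin along the last axis, then
# check that the minimum is greater than zero.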
s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
s_redm = _ox.reducemin(s_casted, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[0])
s_redm = _ox.reducemin(s_casted, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[-1])
s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
return input_ts.create_and_verify(y, s[0])
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTensor: # noqa
container = _EagerTensor.get_container()
y = torch.any(input_ts.value)
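# ONNX emulation of any(): cast to INT64, ReduceSum along the last axis, then
# check that the sum is greater than zero.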
s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
s_redm = _ox.reducesum(s_casted, [_ox.get_unique_tensor_name('reducesum')], container, None, axes=[-1])
s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
@@ -489,7 +535,7 @@ class _ControlFlowContext:
self.sub_graph = None
def flow_output(self, cond, *outputs):
assert len(outputs) > len(self.loop_states), "The loop body doesn't return enough objects"
assert len(outputs) >= len(self.loop_states), "The loop body doesn't return enough objects"
if self.sub_graph is None:
trc = _EagerTensor.get_trace_session()
self.sub_graph = trc.build_graph(trc.container,
@@ -567,7 +613,10 @@ def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
_EagerTensor._all_ops = {'argmax': argmax,
'softmax': softmax,
'reshape': reshape,
'transpose': transpose}
'transpose': transpose,
'repeat': repeat,
'any': any,
'all': all}
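# These registrations expose the ONNX-emitting implementations above as
# _EagerTensor methods (e.g. input_unfinished_sents.any() in tutorials/gpt2bs.py).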
tensor = _EagerTensor.mytensor
tensor_from_onnx = _EagerTensor.from_onnx

View file

@@ -115,7 +115,8 @@ if not os.path.exists(gpt2_core_model_path) or \
output_ms = inference_and_dump_full_model(input_text)
# 3. Inference on the all-in-one model
full_model = eager_op.EagerOp.from_model(gpt2_full_model_path)
from onnxruntime_extensions import PyOrtFunction
full_model = PyOrtFunction.from_model(gpt2_full_model_path)
output_text = full_model(input_text, num_tokens_to_produce)
# 4. Test the result

tutorials/gpt2bs.py (new file, 138 lines)
View file

@@ -0,0 +1,138 @@
import os
from transformers import AutoConfig
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
# Create a cache directory to store pretrained model.
cache_dir = os.path.join(".", "cache_models")
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
model_name_or_path = "gpt2"
device = "cpu"
beam_width = 4
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
num_attention_heads = config.n_head
hidden_size = config.n_embd
num_layer = config.n_layer
gpt2_full_model_path = "./gpt2_full.onnx"
# this model was generated by this script
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2-OneStepSearch_OnnxRuntime_CPU.ipynb
onnx_model_path = "gpt2_one_step_search.onnx"
func_one_step = pyfunc_from_model(onnx_model_path)
def get_tokenizer(model_name_or_path, cache_dir):
from transformers import GPT2Tokenizer # noqa
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
gpt2_encoder_model_path = './gpt2_tok.onnx'
build_customop_model('GPT2Tokenizer', gpt2_encoder_model_path, model=tokenizer)
return tokenizer, pyfunc_from_model(gpt2_encoder_model_path)
def inference_and_dump_full_model(tokenizer, func_tokenizer, input_text, num_tokens_to_produce = 30):
with trace_for_onnx(input_text, num_tokens_to_produce, names=func_tokenizer.input_names) as tc_sess:
inputs, num_tokens = tc_sess.get_inputs()
input_ids, attention_mask = func_tokenizer(inputs, padding=True, padding_side='left')
attention_mask = attention_mask.type(torch.float)
position_ids = (attention_mask.long().cumsum(-1) - 1)
# position_ids.masked_fill_(position_ids < 0, 0)
# Empty Past State for generating first word
# batch_size = input_ids.size()[0]
batch_size = 1
past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
empty_past = []
for _ in range(num_layer):
empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))
beam_select_idx = torch.zeros([1, batch_size]).long()
input_log_probs = torch.zeros([batch_size, 1])
input_unfinished_sents = torch.ones([batch_size, 1], dtype=torch.bool)
prev_step_scores = torch.zeros([batch_size, 1])
beam_size = beam_width
prev_step_results = input_ids.clone().detach().to(device)
cfg = torch.control_flow()
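# Trace the generation loop: cfg.loop() yields the loop-carried states and
# cfg.finalize() folds the recorded body into an ONNX Loop node.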
for states in cfg.loop(num_tokens, torch.tensor(True), input_ids, position_ids,
attention_mask, beam_select_idx, input_log_probs,
input_unfinished_sents, prev_step_results, prev_step_scores, *empty_past):
step = states[0]
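# Annotate the loop-carried tensors with symbolic dimensions; their concrete
# sizes change between iterations, so the graph must use dynamic shapes.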
states[1].symbolic_shape = ['batch_size', 'seq_len']
states[2].symbolic_shape = ['batch_size', 'seq_len']
states[3].symbolic_shape = ['batch_size', 'all_seq_len']
states[4].symbolic_shape = [1, 'batch_size']
# prev_step_results
states[7].symbolic_shape = ['batch_size', 'total_seq_len']
for st_ in states[-num_layer:]:
st_.symbolic_shape = [2, 'batch_size', num_attention_heads, 'past_seq_len', hidden_size // num_attention_heads]
prev_attention_mask = states[3]
outputs = func_one_step(*states[1:])
last_state = outputs[0].clone().detach().cpu()
input_ids = last_state.reshape([batch_size * beam_size, -1]).to(device)
input_unfinished_sents_id = -3
prev_step_results = outputs[-2].clone().detach().to(device)
# position_ids = (torch.tensor([context_length + step - 1
# ]).unsqueeze(0).repeat(batch_size * beam_size, 1).to(device))
position_ids = torch.zeros([batch_size * beam_size, 1], dtype=torch.int64) + attention_mask.size()[-1]
factor = (~step.type(torch.bool)).type(torch.int64)
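# factor is 1 only on the first iteration (step == 0): the previous attention
# mask is tiled from batch_size to batch_size * beam_size rows once, and on
# later iterations factor is 0, leaving the mask unchanged.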
prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
attention_mask = torch.cat(
[
prev_attention_mask,
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
],
1,
).to(device)
beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
input_unfinished_sents = outputs[input_unfinished_sents_id].clone().detach().to(device)
prev_step_scores = outputs[-1].clone().detach().to(device)
past = []
for i in range(num_layer):
past_i = outputs[i + 1].clone().detach()
past.append(past_i.to(device))
any_unfinished = input_unfinished_sents.any()
input_ids.symbolic_shape = ['total_batch_size', 'seq_len']
position_ids.symbolic_shape = ['total_batch_size', 'seq_len']
attention_mask.symbolic_shape = ['total_batch_size', 'all_seq_len']
prev_step_results.symbolic_shape = ['total_batch_size', 'step_seq_len']
for st_ in past:
st_.symbolic_shape = [2, 'total_batch_size', num_attention_heads, 'all_seq_len', hidden_size // num_attention_heads]
cfg.flow_output(any_unfinished, input_ids,
position_ids, attention_mask, beam_select_idx,
input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, *past)
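# Loop output 6 (zero-based, after the termination condition) is
# prev_step_results, which accumulates the generated token ids.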
result_id = 6
all_token_ids = cfg.finalize()[result_id]
tc_sess.save_as_onnx(gpt2_full_model_path, all_token_ids)
print(tokenizer.decode(all_token_ids.t[0], skip_special_tokens=True))
def verify_bsfull_model(input_text):
from onnxruntime_extensions import PyOrtFunction
gpt2_all = PyOrtFunction.from_model(gpt2_full_model_path)
outputs = gpt2_all(input_text, 30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
if __name__ == "__main__":
tokenizer, func_tokenizer = get_tokenizer(model_name_or_path, cache_dir)
input_text = ['best hotel in bay area.']
inference_and_dump_full_model(tokenizer, func_tokenizer, input_text)
verify_bsfull_model(input_text)