ort-customops/test/test_processing.py

import onnx
import numpy
import torch
import unittest
from typing import List, Tuple
from PIL import Image
from distutils.version import LooseVersion
from onnxruntime_extensions import OrtPyFunction
from onnxruntime_extensions import pnp, util
from transformers import GPT2Config, GPT2LMHeadModel


class _GPT2LMHeadModel(GPT2LMHeadModel):
    """Wrap GPT2LMHeadModel for ONNX model conversion, dropping the past-state outputs."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, attention_mask):
        result = super(_GPT2LMHeadModel, self).forward(
            input_ids, attention_mask=attention_mask, return_dict=False
        )
        # drop the past states
        return result[0]
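

# Scripted (rather than traced), presumably so the List[Tensor] argument keeps
# its typed signature through ONNX export.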
@torch.jit.script
def _broadcasting_add(input_list: List[torch.Tensor]) -> torch.Tensor:
    return input_list[1] + input_list[0]


class _SequenceTensorModel(pnp.ProcessingScriptModule):
    def forward(self, img_list: List[torch.Tensor]) -> torch.Tensor:
        return _broadcasting_add(img_list)
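

# Stitches PreMobileNet preprocessing, the ONNX model invocation, and ImageNet
# post-processing into a single scriptable module.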
class _MobileNetProcessingModule(pnp.ProcessingScriptModule):
    def __init__(self, oxml):
        super(_MobileNetProcessingModule, self).__init__()
        self.model_function_id = pnp.create_model_function(oxml)
        self.pre_proc = torch.jit.trace(
            pnp.PreMobileNet(224), torch.zeros(224, 224, 3, dtype=torch.float32)
        )
        self.post_proc = torch.jit.trace(
            pnp.ImageNetPostProcessing(), torch.zeros(1, 1000, dtype=torch.float32)
        )

    def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        proc_input = self.pre_proc(img)
        return self.post_proc.forward(
            pnp.invoke_onnx_model1(self.model_function_id, proc_input)
        )
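

# End-to-end tests: run the PyTorch processing pipeline and verify that the
# exported ONNX model produces matching results through onnxruntime.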
@unittest.skipIf(
    LooseVersion(torch.__version__) < LooseVersion("1.9"),
    "Does not work with older PyTorch",
)
class TestPreprocessing(unittest.TestCase):
    def test_imagenet_preprocessing(self):
        mnv2 = onnx.load_model(util.get_test_data_file("data", "mobilev2.onnx"))
        # load an image
        img = Image.open(util.get_test_data_file("data", "pineapple.jpg"))
        img = torch.from_numpy(numpy.asarray(img.convert("RGB")))
        full_models = pnp.SequentialProcessingModule(
            pnp.PreMobileNet(224), mnv2, pnp.PostMobileNet()
        )
        ids, probabilities = full_models.forward(img)
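        # Export the combined pipeline as a single ONNX model and run it
        # through onnxruntime for comparison with the eager result above.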
        name_i = "image"
        full_model_func = OrtPyFunction.from_model(
            pnp.export(
                full_models,
                img,
                opset_version=11,
                output_path="temp_imagenet.onnx",
                input_names=[name_i],
                dynamic_axes={name_i: [0, 1]},
            )
        )
        actual_ids, actual_result = full_model_func(img.numpy())
        numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
        # 953 is the pineapple class id in the ImageNet dataset
        self.assertEqual(ids[0, 0].item(), 953)

    def test_gpt2_preprocessing(self):
        cfg = GPT2Config(n_layer=3)
        gpt2_m = _GPT2LMHeadModel(cfg)
        gpt2_m.eval().to("cpu")
        test_sentence = ["Test a sentence"]
        tok = pnp.PreHuggingFaceGPT2(
            vocab_file=util.get_test_data_file("data", "gpt2.vocab"),
            merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
        )
        inputs = tok.forward(test_sentence)
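        # Sanity-check that the tokenizer and the model each export on their
        # own before exporting the combined pipeline.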
        pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
        with open("temp_gpt2lmh.onnx", "wb") as f:
            torch.onnx.export(
                gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
            )
        pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
        full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
        expected = full_model.forward(test_sentence)
        model = pnp.export(
            full_model, test_sentence, opset_version=12, do_constant_folding=False
        )
        mfunc = OrtPyFunction.from_model(model)
        actuals = mfunc(test_sentence)
        # The random weights can produce large numeric differences in the
        # results, so only compare the output shapes.
        self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
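
    # A List[Tensor] input maps to an ONNX sequence type; verify that the
    # eager result and the exported model agree.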
    def test_sequence_tensor(self):
        seq_m = _SequenceTensorModel()
        test_input = [
            torch.from_numpy(_i)
            for _i in [
                numpy.array([1]).astype(numpy.int64),
                numpy.array([3, 4]).astype(numpy.int64),
                numpy.array([5, 6]).astype(numpy.int64),
            ]
        ]
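        # _broadcasting_add uses only the first two tensors: [3, 4] + [1] -> [4, 5]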
        res = seq_m.forward(test_input)
        numpy.testing.assert_allclose(res, numpy.array([4, 5]))
        if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
            # The fix for sequence tensor support was only released in 1.11 and above.
            oxml = pnp.export(
                seq_m, test_input, opset_version=12, output_path="temp_seqtest.onnx"
            )
            # TODO: ORT doesn't accept the default empty element type of a sequence type.
            oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
                onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT64, [])
            )
            mfunc = OrtPyFunction.from_model(oxml)
            o_res = mfunc([_i.numpy() for _i in test_input])
            numpy.testing.assert_allclose(res, o_res)


if __name__ == "__main__":
    unittest.main()