PyOp attribute supports int and float data type (#425)
This commit is contained in:
Родитель
9cbd2ada18
Коммит
46efcb9051
|
@ -77,7 +77,7 @@ The PyTorch and TensorFlow converters support custom operator generation if the
|
|||
|
||||
## Add a new custom operator to onnxruntime-extensions
|
||||
|
||||
You can contribute customop C++ implementations directly in this repository if they have general applicability to other users. In addition, if you want to quickly verify the ONNX model with Python, you can wrap the custom operator with PyOp.
|
||||
You can contribute customop C++ implementations directly in this repository if they have general applicability to other users. In addition, if you want to quickly verify the ONNX model with Python, you can wrap the custom operator with **[PyOp](docs/pyop.md)**.
|
||||
|
||||
```python
|
||||
import numpy
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# PyOp
|
||||
|
||||
Custom operators are a powerful feature in ONNX Runtime that allows users to extend the functionality of the runtime by implementing their own operators to perform specific operations not available in the standard ONNX operator set.
|
||||
|
||||
In this document, we will introduce how to create a custom operator using Python functions and integrate it into ONNX Runtime for inference.
|
||||
|
||||
|
||||
## Step 1: Define the Python function for the custom operator
|
||||
Start by defining the Python function that will serve as the implementation for your custom operator. Ensure that the function is compatible with the input and output tensor shapes you expect for your custom operator.
|
||||
the Python decorator @onnx_op will convert the function to be a custom operator implementation. The following is example we create a function for a tokenizer
|
||||
|
||||
```Python
|
||||
@onnx_op(op_type="GPT2Tokenizer",
|
||||
inputs=[PyCustomOpDef.dt_string],
|
||||
outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_int64],
|
||||
attrs={"padding_length": PyCustomOpDef.dt_int64})
|
||||
def bpe_tokenizer(s, **kwargs):
|
||||
padding_length = kwargs["padding_length"]
|
||||
input_ids, attention_mask = cls.tokenizer.tokenizer_sentence([s[0]], padding_length)
|
||||
return input_ids, attention_mask
|
||||
```
|
||||
Because ONNXRuntimme needs the custom operator schema on loading a model, please specify them by onnx_op arguments. Also 'attrs' is needed if there are attributes for the ONNX node, which can be dict that mapping from its name to its type, or be a list if all types are string only.
|
||||
|
||||
## Step 2: Create an ONNX model with the custom operator
|
||||
Now that the custom operator is registered with ONNX Runtime, you can create an ONNX model that utilizes it. You can either modify an existing ONNX model to include the custom operator or create a new one from scratch.
|
||||
|
||||
To create a new ONNX model with the custom operator, you can use the ONNX Python API. Here is an example:[test_pyops.py](../test/test_pyops.py)
|
|
@ -60,7 +60,9 @@ class Opdef:
|
|||
opdef._nativedef.output_types = outputs
|
||||
attrs = kwargs.get('attrs', None)
|
||||
if attrs is None:
|
||||
attrs = []
|
||||
attrs = {}
|
||||
elif isinstance(attrs, (list, tuple)):
|
||||
attrs = {k: PyCustomOpDef.dt_string for k in attrs}
|
||||
opdef._nativedef.attrs = attrs
|
||||
add_custom_op(opdef._nativedef)
|
||||
return opdef
|
||||
|
@ -68,6 +70,20 @@ class Opdef:
|
|||
def __call__(self, *args, **kwargs):
|
||||
return self.body(*args, **kwargs)
|
||||
|
||||
def cast_attributes(self, attributes):
|
||||
res = {}
|
||||
for k, v in attributes.items():
|
||||
if self._nativedef.attrs[k] == PyCustomOpDef.dt_int64:
|
||||
res[k] = int(v)
|
||||
elif self._nativedef.attrs[k] == PyCustomOpDef.dt_float:
|
||||
res[k] = float(v)
|
||||
elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
|
||||
res[k] = v
|
||||
else:
|
||||
raise RuntimeError("Unsupported attribute type {}.".format(
|
||||
self._nativedef.attrs[k]))
|
||||
return res
|
||||
|
||||
|
||||
def _on_pyop_invocation(k_id, feed, attributes):
|
||||
if k_id not in Opdef._odlist:
|
||||
|
@ -75,7 +91,7 @@ def _on_pyop_invocation(k_id, feed, attributes):
|
|||
"Unable to find function id={}. "
|
||||
"Did you decorate the operator with @onnx_op?.".format(k_id))
|
||||
op_ = Opdef._odlist[k_id]
|
||||
rv = op_.body(*feed, **attributes)
|
||||
rv = op_.body(*feed, **op_.cast_attributes(attributes))
|
||||
if isinstance(rv, tuple):
|
||||
# Multiple outputs.
|
||||
res = []
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "string_tensor.h"
|
||||
#include "pykernel.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
const int PyCustomOpDef::undefined = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
|
@ -211,37 +210,62 @@ typedef struct {
|
|||
} InputInformation;
|
||||
|
||||
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info,
|
||||
uint64_t id, const std::vector<std::string>& attrs)
|
||||
uint64_t id, const std::map<std::string, int>& attrs)
|
||||
: api_(api),
|
||||
ort_(api_),
|
||||
obj_id_(id) {
|
||||
size_t size;
|
||||
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
|
||||
size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), nullptr, &size);
|
||||
for (std::map<std::string, int>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
|
||||
std::string attr_name = it->first;
|
||||
int attr_type = it->second;
|
||||
OrtStatus* status = nullptr;
|
||||
std::string attr_value;
|
||||
if (attr_type == PyCustomOpDef::dt_int64) {
|
||||
int64_t value = 0;
|
||||
status = api_.KernelInfoGetAttribute_int64(&info, attr_name.c_str(), &value);
|
||||
if (status == nullptr) {
|
||||
std::stringstream ss;
|
||||
ss << value;
|
||||
attr_value = ss.str();
|
||||
}
|
||||
} else if (attr_type == PyCustomOpDef::dt_float) {
|
||||
float value = 0.f;
|
||||
status = api_.KernelInfoGetAttribute_float(&info, attr_name.c_str(), &value);
|
||||
if (status == nullptr) {
|
||||
std::stringstream ss;
|
||||
ss << value;
|
||||
attr_value = ss.str();
|
||||
}
|
||||
} else if (attr_type == PyCustomOpDef::dt_string) {
|
||||
size_t size = 0;
|
||||
status = api_.KernelInfoGetAttribute_string(&info, attr_name.c_str(), nullptr, &size);
|
||||
if (status == nullptr || api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) {
|
||||
attr_value = std::string(size, ' ');
|
||||
status = api_.KernelInfoGetAttribute_string(&info, attr_name.c_str(), attr_value.data(), &size);
|
||||
if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
|
||||
api_.ReleaseStatus(status);
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to retrieve attribute '", attr_name, "' due to '",
|
||||
api_.GetErrorMessage(status), "'."));
|
||||
}
|
||||
if (status != nullptr) {
|
||||
api_.ReleaseStatus(status);
|
||||
}
|
||||
attr_value.resize(size - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
|
||||
std::string error_message(api_.GetErrorMessage(status));
|
||||
api_.ReleaseStatus(status);
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to find attribute '", *it, "' due to '",
|
||||
"Unable to find attribute '", attr_name, "' due to '",
|
||||
error_message, "'."));
|
||||
}
|
||||
if (status != nullptr) {
|
||||
api_.ReleaseStatus(status);
|
||||
}
|
||||
attrs_values_[*it] = "";
|
||||
attrs_values_[*it].resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), &(attrs_values_[*it][0]), &size);
|
||||
if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
|
||||
api_.ReleaseStatus(status);
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to retrieve attribute '", *it, "' due to '",
|
||||
api_.GetErrorMessage(status), "'."));
|
||||
}
|
||||
attrs_values_[*it].resize(size - 1);
|
||||
if (status != nullptr) {
|
||||
api_.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
attrs_values_[attr_name] = attr_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -358,7 +382,6 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
std::map<std::string, std::vector<PyCustomOpFactory>>& PyOp_container() {
|
||||
static std::map<std::string, std::vector<PyCustomOpFactory>> map_custom_opdef;
|
||||
return map_custom_opdef;
|
||||
|
@ -376,7 +399,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
|
|||
|
||||
// No need to protect against concurrent access, GIL is doing that.
|
||||
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
|
||||
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
||||
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
||||
assert(success || !it_domain_op->second.empty());
|
||||
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
|
||||
}
|
||||
|
@ -391,7 +414,8 @@ const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t num) {
|
|||
if (it != PyOp_container().end()) {
|
||||
const std::vector<PyCustomOpFactory>& ref = it->second;
|
||||
if (num < ref.size()) {
|
||||
return ref.data() + num; }
|
||||
return ref.data() + num;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -399,7 +423,7 @@ const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t num) {
|
|||
|
||||
const OrtCustomOp* FetchPyCustomOps(size_t& num) {
|
||||
auto ptr = PyCustomOpDef_FetchPyCustomOps(num);
|
||||
if (ptr == nullptr) // For the breakpoint in debugging.
|
||||
if (ptr == nullptr) // For the breakpoint in debugging.
|
||||
return nullptr;
|
||||
return ptr;
|
||||
}
|
||||
|
@ -411,20 +435,20 @@ bool EnablePyCustomOps(bool enabled) {
|
|||
return last;
|
||||
}
|
||||
|
||||
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* ortApi){
|
||||
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* ortApi) {
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
OrtStatus* status = nullptr;
|
||||
|
||||
for (auto const& val_pair: PyOp_container()) {
|
||||
for (auto const& val_pair : PyOp_container()) {
|
||||
if (val_pair.first == c_OpDomain) {
|
||||
continue; // Register this domain in the second iteration.
|
||||
continue; // Register this domain in the second iteration.
|
||||
}
|
||||
|
||||
if (status = ortApi->CreateCustomOpDomain(val_pair.first.c_str(), &domain); status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
for (auto const& cop: val_pair.second) {
|
||||
for (auto const& cop : val_pair.second) {
|
||||
if (status = ortApi->CustomOpDomain_Add(domain, &cop); status) {
|
||||
return status;
|
||||
}
|
||||
|
@ -453,8 +477,10 @@ uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
|
|||
void AddGlobalMethods(pybind11::module& m) {
|
||||
m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
|
||||
m.def("enable_py_op", &EnablePyCustomOps, "Enable or disable pyop functions.");
|
||||
m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); }, "Add a PyOp Python object.");
|
||||
m.def("default_opset_domain", []{return std::string(c_OpDomain);}, "return the default opset domain name.");
|
||||
m.def(
|
||||
"add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); }, "Add a PyOp Python object.");
|
||||
m.def(
|
||||
"default_opset_domain", [] { return std::string(c_OpDomain); }, "return the default opset domain name.");
|
||||
}
|
||||
|
||||
void AddObjectMethods(pybind11::module& m) {
|
||||
|
@ -465,8 +491,7 @@ void AddObjectMethods(pybind11::module& m) {
|
|||
.def_readwrite("input_types", &PyCustomOpDef::input_types)
|
||||
.def_readwrite("output_types", &PyCustomOpDef::output_types)
|
||||
.def_readwrite("attrs", &PyCustomOpDef::attrs)
|
||||
.def_static("install_hooker", [](py::object obj) {
|
||||
PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
|
||||
.def_static("install_hooker", [](py::object obj) { PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
|
||||
.def_readonly_static("undefined", &PyCustomOpDef::undefined)
|
||||
.def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
|
||||
.def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)
|
||||
|
@ -494,6 +519,6 @@ PYBIND11_MODULE(_extensions_pydll, m) {
|
|||
AddObjectMethods(m);
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() {
|
||||
PyCustomOpDefImpl::op_invoker.reset();
|
||||
PyCustomOpDefImpl::op_invoker.reset();
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -3,16 +3,17 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "ocos.h"
|
||||
|
||||
struct PyCustomOpDef {
|
||||
std::string op_type;
|
||||
uint64_t obj_id = 0;
|
||||
std::vector<int> input_types;
|
||||
std::vector<int> output_types;
|
||||
std::vector<std::string> attrs;
|
||||
std::map<std::string, int> attrs;
|
||||
|
||||
static void AddOp(const PyCustomOpDef* cod);
|
||||
|
||||
|
@ -37,7 +38,7 @@ struct PyCustomOpDef {
|
|||
};
|
||||
|
||||
struct PyCustomOpKernel {
|
||||
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info, uint64_t id, const std::vector<std::string>& attrs);
|
||||
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info, uint64_t id, const std::map<std::string, int>& attrs);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -76,10 +77,6 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
|
|||
return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
|
||||
};
|
||||
|
||||
const std::vector<std::string>& GetAttributesNames() const {
|
||||
return opdef_->attrs;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return opdef_->output_types.size();
|
||||
};
|
||||
|
@ -93,5 +90,4 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
|
|||
std::string op_domain_;
|
||||
};
|
||||
|
||||
|
||||
bool EnablePyCustomOps(bool enable = true);
|
||||
|
|
|
@ -5,7 +5,8 @@ import onnxruntime as _ort
|
|||
from onnx import helper, onnx_pb as onnx_proto
|
||||
from transformers import GPT2Tokenizer
|
||||
from onnxruntime_extensions import (
|
||||
util,
|
||||
PyCustomOpDef,
|
||||
onnx_op, util,
|
||||
make_onnx_model,
|
||||
enable_py_op,
|
||||
get_library_path as _get_library_path)
|
||||
|
@ -75,14 +76,18 @@ class TestGPT2Tokenizer(unittest.TestCase):
|
|||
cls.merges = util.get_test_data_file('data', 'gpt2.merges.txt')
|
||||
cls.tokenizer = MyGPT2Tokenizer(cls.tokjson, cls.merges)
|
||||
|
||||
# @onnx_op(op_type="GPT2Tokenizer",
|
||||
# inputs=[PyCustomOpDef.dt_string],
|
||||
# outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_int64],
|
||||
# attrs=["padding_length"])
|
||||
# def bpe_tokenizer(s, **kwargs):
|
||||
# padding_length = kwargs["padding_length"]
|
||||
# input_ids, attention_mask = cls.tokenizer.tokenizer_sentence(s, padding_length)
|
||||
# return input_ids, attention_mask
|
||||
@onnx_op(op_type="GPT2Tokenizer",
|
||||
inputs=[PyCustomOpDef.dt_string],
|
||||
outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_int64],
|
||||
attrs={"padding_length": PyCustomOpDef.dt_int64})
|
||||
def bpe_tokenizer(s, **kwargs):
|
||||
padding_length = kwargs["padding_length"]
|
||||
input_ids, attention_mask = cls.tokenizer.tokenizer_sentence([s[0]], padding_length)
|
||||
return input_ids, attention_mask
|
||||
|
||||
def tearDown(self) -> None:
|
||||
enable_py_op(True)
|
||||
return super().tearDown()
|
||||
|
||||
def _run_tokenizer(self, test_sentence, padding_length=-1):
|
||||
model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True)
|
||||
|
@ -95,9 +100,6 @@ class TestGPT2Tokenizer(unittest.TestCase):
|
|||
np.testing.assert_array_equal(expect_input_ids, input_ids)
|
||||
np.testing.assert_array_equal(expect_attention_mask, attention_mask)
|
||||
|
||||
del sess
|
||||
del so
|
||||
|
||||
def test_tokenizer(self):
|
||||
enable_py_op(False)
|
||||
|
||||
|
@ -112,9 +114,9 @@ class TestGPT2Tokenizer(unittest.TestCase):
|
|||
self._run_tokenizer(["I can feel the magic, can you?", "Yes I do."])
|
||||
self._run_tokenizer(["I can feel the magic, can you?", "Yes I do."], 100)
|
||||
|
||||
enable_py_op(True)
|
||||
|
||||
def test_optional_outputs(self):
|
||||
enable_py_op(False)
|
||||
|
||||
# Test for model without attention mask (input id output is always required)
|
||||
model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False)
|
||||
so = _ort.SessionOptions()
|
||||
|
@ -132,15 +134,15 @@ class TestGPT2Tokenizer(unittest.TestCase):
|
|||
np.testing.assert_array_equal(expect_input_ids, outputs[0])
|
||||
|
||||
|
||||
# def test_tokenizer_pyop(self):
|
||||
# self._run_tokenizer(["I can feel the magic, can you?"])
|
||||
# self._run_tokenizer(["Hey Cortana"])
|
||||
# self._run_tokenizer(["你好123。david"])
|
||||
# self._run_tokenizer(["爱你一三一四"])
|
||||
# self._run_tokenizer(["women'thinsulate 3 button leather car co"])
|
||||
# self._run_tokenizer(["#$%^&()!@?><L:{}\\[];',./`ǠǡǢǣǤǥǦǧǨ"])
|
||||
# self._run_tokenizer(["ڠڡڢڣڤڥڦڧڨکڪګڬڭڮگ"])
|
||||
# self._run_tokenizer(["⛀⛁⛂⛃⛄⛅⛆⛇⛈⛉⛊⛋⛌⛍⛎⛏"])
|
||||
def test_tokenizer_pyop(self):
|
||||
self._run_tokenizer(["I can feel the magic, can you?"])
|
||||
self._run_tokenizer(["Hey Cortana"])
|
||||
self._run_tokenizer(["你好123。david"])
|
||||
self._run_tokenizer(["爱你一三一四"])
|
||||
self._run_tokenizer(["women'thinsulate 3 button leather car co"])
|
||||
self._run_tokenizer(["#$%^&()!@?><L:{}\\[];',./`ǠǡǢǣǤǥǦǧǨ"])
|
||||
self._run_tokenizer(["ڠڡڢڣڤڥڦڧڨکڪګڬڭڮگ"])
|
||||
self._run_tokenizer(["⛀⛁⛂⛃⛄⛅⛆⛇⛈⛉⛊⛋⛌⛍⛎⛏"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Загрузка…
Ссылка в новой задаче