Handles dummy python operators for double and strings (#7)
* refactor tests
* Update mshost.yaml
* Implements dummy operators with double and strings
* update CI
* Implements StringUpper C++ version
* Fix runtime issue that prevented registering multiple python ops
* add c++ operator StringJoin
* Support multi output for python and C++ operators
* remove torch in requirements-dev.txt
This commit is contained in:
Parent: 3d13eb867c
Commit: e36205ee83
@@ -17,19 +17,37 @@ jobs:
- powershell: Write-Host "##vso[task.prependpath]$env:CONDA\Scripts"
  displayName: Add conda to PATH

- script: conda create --yes --quiet --name py37 -c conda-forge python=3.7 numpy
- script: conda create --yes --quiet --name pyenv -c conda-forge python=3.7 numpy
  displayName: Create Anaconda environment

- script: |
    call activate py37
    python -m pip install --upgrade pip numpy
    call activate pyenv
    python -m pip install --upgrade pip
    python -m pip install -r requirements.txt
  displayName: Install requirements.txt

- script: |
    call activate pyenv
    echo Test numpy installation... && python -c "import numpy"
    call .\build.bat
  displayName: Build the custom-op library

- script: |
    call activate py37
    python -m pip install onnxruntime onnxconverter_common
    call activate pyenv
    python -m pip install -e out\Windows\RelWithDebInfo
    python .\test\test_pyops.py
  displayName: Install the custom-op library

- script: |
    call activate pyenv
    python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
  displayName: Install pytorch

- script: |
    call activate pyenv
    python -m pip install -r requirements-dev.txt
  displayName: Install requirements-dev.txt

- script: |
    call activate pyenv
    python -m pytest test
  displayName: Run python test

@@ -3,9 +3,11 @@
#include "ocos.h"
#include "kernels/kernels.h"
#include "utils.h"

#include <vector>
#include <cmath>
#include <algorithm>

struct OrtTensorDimensions : std::vector<int64_t> {
  OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
@@ -81,6 +83,121 @@ struct KernelTwo {
  Ort::CustomOpApi ort_;
};

struct KernelStringUpper {
  KernelStringUpper(OrtApi api)
      : api_(api),
        ort_(api_) {
  }

  void Compute(OrtKernelContext* context) {
    // Setup inputs
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
    const std::string* X = ort_.GetTensorData<std::string>(input_X);

    // Setup output
    OrtTensorDimensions dimensions(ort_, input_X);
    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
    std::string* out = ort_.GetTensorMutableData<std::string>(output);

    OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
    int64_t size = ort_.GetTensorShapeElementCount(output_info);
    ort_.ReleaseTensorTypeAndShapeInfo(output_info);

    // Do computation
    for (int64_t i = 0; i < size; i++) {
      out[i] = X[i];
      std::transform(out[i].begin(), out[i].end(), out[i].begin(), ::toupper);
    }
  }

 private:
  OrtApi api_;  // keep a copy of the struct, whose reference is used by ort_
  Ort::CustomOpApi ort_;
};
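Reviewer note: in numpy terms, StringUpper upper-cases every element while keeping the tensor shape. A minimal sketch (hypothetical variable names); note the C++ kernel applies the byte-wise ASCII `::toupper`, so non-ASCII characters such as "é" pass through unchanged, which is what the accent tests below rely on:

```python
import numpy as np

x = np.array([["Abc"], ["déf"]])
out = np.array([s.upper() for s in x.ravel()]).reshape(x.shape)
# Python's str.upper() gives [["ABC"], ["DÉF"]]; the C++ kernel,
# applying ASCII ::toupper byte by byte, would leave "é" as-is.
```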

struct KernelStringJoin {
  KernelStringJoin(OrtApi api)
      : api_(api),
        ort_(api_) {
  }

  void Compute(OrtKernelContext* context) {
    // Setup inputs
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
    const std::string* X = ort_.GetTensorData<std::string>(input_X);
    const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
    const std::string* sep = ort_.GetTensorData<std::string>(input_sep);

    // Setup output
    OrtTensorDimensions dimensions_sep(ort_, input_sep);
    if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
      throw std::runtime_error("Input 2 is the separator; it must have exactly 1 element.");
    OrtTensorDimensions dimensions(ort_, input_X);
    if (dimensions.size() != 2)
      throw std::runtime_error(MakeString("Input 1 must have 2 dimensions but has ", dimensions.size(), "."));
    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), 1);
    std::string* out = ort_.GetTensorMutableData<std::string>(output);

    OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
    int64_t size = ort_.GetTensorShapeElementCount(output_info);
    ort_.ReleaseTensorTypeAndShapeInfo(output_info);

    // Do computation
    int64_t index = 0;
    for (int64_t i = 0; i < size; ++i) {
      std::ostringstream st;
      for (int64_t j = 0; j < dimensions[1] - 1; ++j, ++index) {
        st << X[index] << *sep;
      }
      st << X[index++];
      out[i] = st.str();
    }
  }

 private:
  OrtApi api_;  // keep a copy of the struct, whose reference is used by ort_
  Ort::CustomOpApi ort_;
};
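For intuition, StringJoin takes an [N, C] string tensor plus a one-element separator and emits N joined strings; in numpy terms (mirroring the tests added below):

```python
import numpy as np

text = np.array([["a", "b", "c"],
                 ["aa", "bb", ""]])
sep = np.array([";"])
out = np.array([sep[0].join(row) for row in text])
# -> ["a;b;c", "aa;bb;"]
```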

struct KernelNegPos {
  KernelNegPos(OrtApi api)
      : api_(api),
        ort_(api_) {
  }

  void Compute(OrtKernelContext* context) {
    // Setup inputs
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
    const float* X = ort_.GetTensorData<float>(input_X);

    // Setup output
    OrtTensorDimensions dimensions(ort_, input_X);

    OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
    float* out0 = ort_.GetTensorMutableData<float>(output0);
    OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
    float* out1 = ort_.GetTensorMutableData<float>(output1);

    OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
    int64_t size = ort_.GetTensorShapeElementCount(output_info);
    ort_.ReleaseTensorTypeAndShapeInfo(output_info);

    // Do computation
    for (int64_t i = 0; i < size; i++) {
      if (X[i] > 0) {
        out0[i] = 0;
        out1[i] = X[i];
      } else {
        out0[i] = X[i];
        out1[i] = 0;
      }
    }
  }

 private:
  OrtApi api_;  // keep a copy of the struct, whose reference is used by ort_
  Ort::CustomOpApi ort_;
};
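NegPos splits a tensor into its negative and positive parts, so the two outputs always sum back to the input; a numpy sketch:

```python
import numpy as np

x = np.array([[0., 1., 1.5], [7., 8., -5.5]], np.float32)
neg = np.where(x > 0, 0, x)   # first output: zero where x is positive
pos = np.where(x > 0, x, 0)   # second output: zero elsewhere
assert np.allclose(x, neg + pos)
```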

struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
@@ -112,14 +229,58 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {

} c_CustomOpTwo;

struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    return new KernelStringUpper(api);
  };

  const char* GetName() const { return "StringUpper"; };

  size_t GetInputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; };

  size_t GetOutputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; };

} c_CustomOpStringUpper;

struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    return new KernelStringJoin(api);
  };

  const char* GetName() const { return "StringJoin"; };

  size_t GetInputTypeCount() const { return 2; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; };

  size_t GetOutputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; };

} c_CustomOpStringJoin;

struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    return new KernelNegPos(api);
  };

  const char* GetName() const { return "NegPos"; };

  size_t GetInputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

  size_t GetOutputTypeCount() const { return 2; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

} c_CustomOpNegPos;

OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  OrtCustomOpDomain* domain = nullptr;
  const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);

  if (auto status = ortApi->CreateCustomOpDomain("test.customop", &domain)) {
    return status;
  }
  else {
  } else {
    if (auto status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpOne)) {
      return status;
    }
@@ -128,6 +289,18 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
      return status;
    }

    if (auto status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpStringUpper)) {
      return status;
    }

    if (auto status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpStringJoin)) {
      return status;
    }

    if (auto status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpNegPos)) {
      return status;
    }

    if (auto status = ortApi->AddCustomOpDomain(options, domain)) {
      return status;
    }
@@ -142,7 +315,7 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
  for (auto it = custom_op_list; *it != nullptr; ++it) {
    auto obj_ptr = (*it)();
    // TODO: it doesn't make sense that ORT needs a non-const OrtCustomOp object; will fix in a new ORT release.
    OrtCustomOp * op_ptr = const_cast<OrtCustomOp *>(obj_ptr);
    OrtCustomOp* op_ptr = const_cast<OrtCustomOp*>(obj_ptr);
    if (auto status = ortApi->CustomOpDomain_Add(domain, op_ptr)) {
      return status;
    }
@@ -150,13 +323,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA

#if defined(PYTHON_OP_SUPPORT)
  size_t count = 0;
  auto c_ops = FetchPyCustomOps(count);
  for (size_t n = 0; n < count; ++n) {
    // TODO: it doesn't make sense that ORT needs a non-const OrtCustomOp object; will fix in a new ORT release.
    OrtCustomOp * op_ptr = const_cast<OrtCustomOp *>(c_ops + n);
    if (auto status = ortApi->CustomOpDomain_Add(domain, op_ptr)) {
  const OrtCustomOp* c_ops = FetchPyCustomOps(count);
  while (c_ops != nullptr) {
    OrtCustomOp* op_ptr = const_cast<OrtCustomOp*>(c_ops);
    auto status = ortApi->CustomOpDomain_Add(domain, op_ptr);
    if (status) {
      return status;
    }
    ++count;
    c_ops = FetchPyCustomOps(count);
  }
#endif
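With this loop, Python-defined ops are pulled in one at a time until FetchPyCustomOps returns nullptr, so several @onnx_op registrations no longer clobber each other. Consuming the library from Python is unchanged; as the tests do:

```python
import onnxruntime as _ort
from ortcustomops import get_library_path

so = _ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
# Sessions created with `so` can now resolve both the C++ ops
# (domain test.customop) and the Python ops (domain ai.onnx.contrib).
```
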
@@ -10,9 +10,9 @@ The entry point to onnxruntime custom op library
__version__ = "0.0.1"
__author__ = "Microsoft"

from onnxconverter_common.onnx_fx import GraphFunctionType as Types
from onnxconverter_common.onnx_fx import GraphFunctionType as Types  # noqa

from ._ocos import get_library_path
from ._ocos import Opdef
from ._ocos import get_library_path  # noqa
from ._ocos import Opdef, PyCustomOpDef  # noqa

onnx_op = Opdef.declare
@@ -3,9 +3,8 @@
# license information.
###############################################################################

import numpy as np
from pathlib import Path
from ._ortcustomops import (PyCustomOpDef, add_custom_op)
from ._ortcustomops import PyCustomOpDef, add_custom_op


def get_library_path():
@@ -13,12 +12,10 @@ def get_library_path():
    return str(pkg_dir / "_ortcustomops.pyd")


class OpdefList:
    odlist = []


class Opdef:

    _odlist = {}

    def __init__(self, op_type, func):
        self.op_type = op_type
        self.body = func
@@ -27,9 +24,9 @@ class Opdef:
    @staticmethod
    def declare(*args, **kwargs):
        if len(args) > 0 and hasattr(args[0], '__call__'):
            return Opdef._create(args[0])
        else:
            return lambda f: Opdef._create(f, *args, **kwargs)
            raise RuntimeError("Unexpected arguments {}.".format(args))
            # return Opdef._create(args[0])
        return lambda f: Opdef._create(f, *args, **kwargs)

    @staticmethod
    def _create(func, *args, **kwargs):
@@ -37,18 +34,24 @@ class Opdef:
        op_type = name or func.__name__
        opdef = Opdef(op_type, func)
        od_id = id(opdef)
        OpdefList.odlist.append(opdef)

        # Tell python this object cannot be destroyed,
        # because it is also stored in a C++ container.
        Opdef._odlist[od_id] = opdef
        opdef._nativedef = PyCustomOpDef()
        opdef._nativedef.op_type = op_type
        opdef._nativedef.obj_id = od_id

        # TODO: handle more types and multiple inputs/outputs;
        # by default the op is single in/out.
        if kwargs.get('inputs', None) is None:
            opdef._nativedef.input_types = [PyCustomOpDef.dt_float]
        if kwargs.get('outputs', None) is None:
            opdef._nativedef.output_types = [PyCustomOpDef.dt_float]

        inputs = kwargs.get('inputs', None)
        if inputs is None:
            inputs = [PyCustomOpDef.dt_float]
        opdef._nativedef.input_types = inputs
        outputs = kwargs.get('outputs', None)
        if outputs is None:
            outputs = [PyCustomOpDef.dt_float]
        opdef._nativedef.output_types = outputs
        add_custom_op(opdef._nativedef)
        return opdef
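With the new inputs/outputs keywords, an operator's tensor types are declared explicitly, defaulting to a single float in/out; taken from the tests in this PR:

```python
from ortcustomops import onnx_op, PyCustomOpDef

@onnx_op(op_type="AddEpsilon",
         inputs=[PyCustomOpDef.dt_double],
         outputs=[PyCustomOpDef.dt_double])
def add_epsilon(x):
    # Receives and must return a double (float64) tensor.
    return x + 1e-3
```
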
@@ -57,15 +60,22 @@ class Opdef:


def _on_pyop_invocation(k_id, feed):
    for op_ in OpdefList.odlist:
        if op_._nativedef.obj_id == k_id:
            rv = op_.body(*feed)
            return k_id, rv.shape, rv.flatten().tolist()

    # return a dummy result if there is no function found;
    # an exception should be raised in the C++ custom op implementation.
    fetch = np.ones([1, 1], np.float32)
    return 0, fetch.shape, fetch.flatten().tolist()
    if k_id not in Opdef._odlist:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?".format(k_id))
    op_ = Opdef._odlist[k_id]
    rv = op_.body(*feed)
    if isinstance(rv, tuple):
        # Multiple outputs.
        res = []
        for r in rv:
            res.append(r.shape)
            res.append(r.flatten().tolist())
        res = tuple(res)
    else:
        res = (rv.shape, rv.flatten().tolist())
    return (k_id, ) + res


PyCustomOpDef.install_hooker(_on_pyop_invocation)
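The hook's return value is a flat tuple: (k_id, shape, values) for a single output, and (k_id, shape0, values0, shape1, values1, ...) when the op returns a tuple. A sketch of the packing rule, with a hypothetical helper name:

```python
import numpy as np

def _pack_result(k_id, *arrays):
    # Same flattening _on_pyop_invocation applies to each output tensor.
    res = []
    for a in arrays:
        res.append(a.shape)
        res.append(a.flatten().tolist())
    return (k_id,) + tuple(res)

# One output -> (7, (2, 2), [0.0, 0.0, 0.0, 0.0])
print(_pack_result(7, np.zeros((2, 2), np.float32)))
```
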
@@ -16,43 +16,124 @@
#include <pybind11/numpy.h>
#include <thread>

#include "../utils.h"
#include "pykernel.h"

namespace py = pybind11;

const std::map<int, int>& PyCustomOpDef::get_numpy_type_map(bool from_or_to) {
  static std::map<int, int> to_type_map{
      {dt_bool, NPY_BOOL},
      {dt_float, NPY_FLOAT},
      {dt_float16, NPY_FLOAT16},
      {dt_double, NPY_DOUBLE},
      {dt_int8, NPY_INT8},
      {dt_uint8, NPY_UINT8},
      {dt_int16, NPY_INT16},
      {dt_uint16, NPY_UINT16},
      {dt_int32, NPY_INT},
      {dt_uint32, NPY_UINT},
      {dt_int64, NPY_LONGLONG},
      {dt_uint64, NPY_ULONGLONG},
  };
static int to_numpy(ONNXTensorElementDataType dt) {
  switch (dt) {
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return NPY_FLOAT;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return NPY_UINT8;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return NPY_INT8;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return NPY_UINT16;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return NPY_INT16;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return NPY_INT32;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return NPY_INT64;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return NPY_BOOL;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return NPY_FLOAT16;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return NPY_DOUBLE;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return NPY_UINT32;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return NPY_UINT64;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
      return NPY_COMPLEX64;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
      return NPY_COMPLEX128;
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      return NPY_OBJECT;
    default:
      throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
  }
}

static auto from_type_map = [] { std::map<int, int> reversed;
  for (auto it : to_type_map) reversed[it.second] = it.first; return reversed; }();
static size_t element_size(ONNXTensorElementDataType dt) {
  switch (dt) {
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return sizeof(float);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return sizeof(uint8_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return sizeof(int8_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return sizeof(uint16_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return sizeof(int16_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return sizeof(int32_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return sizeof(int64_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return sizeof(bool);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return sizeof(uint16_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return sizeof(double);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return sizeof(uint32_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return sizeof(uint64_t);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
      return sizeof(_C_float_complex);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
      return sizeof(_C_double_complex);
    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      return sizeof(std::string*);
    default:
      throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
  }
}

  return from_or_to ? from_type_map : to_type_map;
static ONNXTensorElementDataType from_numpy(int dt) {
  switch (dt) {
    case NPY_FLOAT:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    case NPY_UINT8:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    case NPY_INT8:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
    case NPY_UINT16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
    case NPY_INT16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
    case NPY_INT32:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
    case NPY_INT64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    case NPY_BOOL:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
    case NPY_FLOAT16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
    case NPY_DOUBLE:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
    case NPY_UINT32:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
    case NPY_UINT64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
    case NPY_COMPLEX64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
    case NPY_COMPLEX128:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
    case NPY_OBJECT:
    case NPY_STRING:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    default:
      throw std::runtime_error("No corresponding ONNX data type/Tensor data Type.");
  }
}

struct PyCustomOpDefImpl : public PyCustomOpDef {
  static int to_numpy(int dt, bool from_or_to = false) {
    auto type_map = get_numpy_type_map(from_or_to);
    const auto it = type_map.find(dt);
    if (it == type_map.end()) {
      throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
    } else {
      return it->second;
    }
  }

  typedef std::vector<int64_t> shape_t;
  static int64_t calc_size_from_shape(const shape_t& sp) {
    size_t c = 1;
@@ -62,25 +143,28 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
    return c;
  }

  static int from_numpy(int dt) {
    return to_numpy(dt, true);
  }

  template <typename _DT>
  static py::object BuildPyObjFromTensor(const _DT* p, const shape_t& shape) {
  static py::object BuildPyObjFromTensor(const void* p, const shape_t& shape, ONNXTensorElementDataType dtype) {
    std::vector<npy_intp> npy_dims;
    for (auto n : shape) {
      npy_dims.push_back(n);
    }

    const int numpy_type = to_numpy(dt_float);
    auto obj = py::reinterpret_borrow<py::object>(PyArray_SimpleNew(
    const int numpy_type = to_numpy(dtype);
    py::object obj = py::reinterpret_steal<py::object>(PyArray_SimpleNew(
        static_cast<int>(shape.size()), npy_dims.data(), numpy_type));

    void* outPtr = static_cast<void*>(
    void* out_ptr = static_cast<void*>(
        PyArray_DATA(reinterpret_cast<PyArrayObject*>(obj.ptr())));

    memcpy(outPtr, p, sizeof(_DT) * calc_size_from_shape(shape));
    if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      py::object* outObj = static_cast<py::object*>(out_ptr);
      auto size = calc_size_from_shape(shape);
      const std::string* src = (const std::string*)p;
      for (int i = 0; i < size; i++, src++) {
        outObj[i] = py::cast(*src);
      }
    } else {
      size_t size_type = element_size(dtype);
      memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
    }
    return obj;
  }
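BuildPyObjFromTensor now materializes any dtype rather than assuming float, and string tensors reach Python as numpy object arrays (NPY_OBJECT); the two branches roughly correspond to:

```python
import numpy as np

# Non-string dtypes: one memcpy into a typed array of the right shape.
floats = np.empty((2, 3), dtype=np.float32)

# Strings: element-wise conversion into an object array.
strings = np.array([["a", "bb"], ["ccc", ""]], dtype=object)
```
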
@@ -98,21 +182,32 @@ std::auto_ptr<PyCustomOpDefImpl::callback_t> PyCustomOpDefImpl::op_invoker;
// static std::condition_variable op_cv;
// static bool is_ready = false;

typedef struct {
  const OrtValue* input_X;
  ONNXTensorElementDataType dtype;
  std::vector<int64_t> dimensions;
} InputInformation;

void PyCustomOpKernel::Compute(OrtKernelContext* context) {
  // std::unique_lock<std::mutex> lck(op_mutex);
  // is_ready = true;
  // op_cv.notify_all();
  // std::this_thread::sleep_for(std::chrono::milliseconds(5000));
  size_t n_inputs = ort_.KernelContext_GetInputCount(context);
  size_t n_outputs = ort_.KernelContext_GetOutputCount(context);

  // Setup inputs
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const float* X = ort_.GetTensorData<float>(input_X);

  // Setup output
  std::vector<int64_t> dimensions;
  OrtTensorTypeAndShapeInfo* info = ort_.GetTensorTypeAndShape(input_X);
  dimensions = (ort_.GetTensorShape(info));
  ort_.ReleaseTensorTypeAndShapeInfo(info);
  std::vector<InputInformation> inputs;
  inputs.reserve(n_inputs);
  for (size_t index = 0; index < n_inputs; ++index) {
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, index);
    std::vector<int64_t> i_dimensions;
    OrtTensorTypeAndShapeInfo* i_info = ort_.GetTensorTypeAndShape(input_X);
    i_dimensions = ort_.GetTensorShape(i_info);
    ONNXTensorElementDataType i_dtype = ort_.GetTensorElementType(i_info);
    ort_.ReleaseTensorTypeAndShapeInfo(i_info);
    inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
  }

  /* Acquire the GIL before calling Python code, because it was released in sess.run */
  py::gil_scoped_acquire acquire;
@@ -131,18 +226,56 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
  //                    sizeof(float)});

  {
    py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(X, dimensions);
    auto feed = py::make_tuple(input0);
    py::tuple fetch = PyCustomOpDefImpl::InvokePyFunction(obj_id_, feed);
    py::list pyinputs;
    for (auto it = inputs.begin(); it != inputs.end(); ++it) {
      py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
          (const void*)ort_.GetTensorData<float>(it->input_X), it->dimensions, it->dtype);
      pyinputs.append(input0);
    }

    // Call the python function; it returns the id plus (shape, flat values) per output.
    py::tuple fetch = PyCustomOpDefImpl::InvokePyFunction(obj_id_, pyinputs);
    int64_t rid = fetch[0].cast<int64_t>();
    assert(rid == obj_id_);
    auto dims = fetch[1].cast<std::vector<int64_t>>();
    auto retval = fetch[2].cast<std::vector<float>>();

    // Setup output.
    for (size_t no = 0; no < n_outputs; ++no) {
      auto dims = fetch[1 + no * 2].cast<std::vector<int64_t>>();
      OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
      OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
      ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
      const void* Y = (const void*)ort_.GetTensorData<float>(output);
      ort_.ReleaseTensorTypeAndShapeInfo(o_info);
      void* out = (void*)ort_.GetTensorMutableData<float>(output);

      if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
        auto retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
        std::string* type_outPtr = (std::string*)out;
        std::string* end = type_outPtr + retval.size();
        const std::string* source = (const std::string*)retval.data();
        for (; type_outPtr != end; ++type_outPtr, ++source) {
          *type_outPtr = *source;
        }
      } else {
        py::array retval = fetch[2 + no * 2].cast<py::array>();
        if (element_size(o_dtype) != retval.itemsize()) {
          switch (o_dtype) {
            case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
              retval = fetch[2 + no * 2].cast<py::array_t<float>>();
              break;
            default:
              throw std::runtime_error(MakeString(
                  "Type mismatch between declared output element size (",
                  element_size(o_dtype), ") and python element size (",
                  retval.itemsize(), ")"));
          }
        }
        size_t size = element_size(o_dtype);
        memcpy(out, retval.data(), size * retval.size());
      }
    }

    py::gil_scoped_release release;
    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dims.data(), dims.size());
    float* out = ort_.GetTensorMutableData<float>(output);
    std::copy(retval.data(), retval.data() + retval.size(), out);

  // TODO: the return value from the python callback function doesn't work in pybind11&numpy.
  // py::gil_scoped_acquire acquire;
@@ -170,16 +303,29 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
  }
}

std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list() {
  static std::vector<PyCustomOpFactory> lst_custom_opdef;
  return lst_custom_opdef;
}

void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
  // No need to protect against concurrent access, GIL is doing that.
  PyCustomOpDef_python_operator_list().push_back(PyCustomOpFactory(cod));
}

const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count) {
  // The result must stay alive.
  std::vector<PyCustomOpFactory>& copy = PyCustomOpDef_python_operator_list();
  if (count < copy.size())
    return &(copy[count]);
  return nullptr;
}

const OrtCustomOp* FetchPyCustomOps(size_t& count) {
  static std::vector<PyCustomOpFactory> c_pycustomops;
  c_pycustomops.clear();

  for (auto od_ptr : PyCustomOpDef::FullList()) {
    c_pycustomops.emplace_back(PyCustomOpFactory(od_ptr));
  }

  count = c_pycustomops.size();
  return c_pycustomops.data();
  auto ptr = PyCustomOpDef_FetchPyCustomOps(count);
  if (ptr == nullptr)
    return nullptr;
  return ptr;
}

// static std::ofstream logger;
@@ -191,7 +337,7 @@ static int init_numpy() {
}

void AddGlobalMethods(pybind11::module& m) {
  m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::FullList().push_back(&cod); });
  m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); });
}

void AddObjectMethods(pybind11::module& m) {
@@ -13,10 +13,7 @@ struct PyCustomOpDef {
  std::vector<int> input_types;
  std::vector<int> output_types;

  static std::vector<const PyCustomOpDef*>& FullList() {
    static std::vector<const PyCustomOpDef*> lst_custom_opdef;
    return lst_custom_opdef;
  }
  static void AddOp(const PyCustomOpDef* cod);

  static const int undefined = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  static const int dt_float = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;  // maps to c type float
@@ -35,21 +32,16 @@ struct PyCustomOpDef {
  static const int dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;    // complex with float32 real and imaginary components
  static const int dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;  // complex with float64 real and imaginary components
  static const int dt_bfloat16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;      // Non-IEEE floating-point format based on IEEE754 single-precision

  static const std::map<int, int>& get_numpy_type_map(bool from_or_to);
};

struct PyCustomOpKernel {
  PyCustomOpKernel(OrtApi api)
  PyCustomOpKernel(OrtApi api, uint64_t id)
      : api_(api),
        ort_(api_),
        obj_id_(0) {
        obj_id_(id) {
  }

  void Compute(OrtKernelContext* context);
  void set_opdef_id(uint64_t id) {
    obj_id_ = id;
  }

 private:
  OrtApi api_;  // keep a copy of the struct, whose reference is used by ort_
@@ -58,35 +50,38 @@ struct PyCustomOpKernel {
};

struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
  PyCustomOpFactory(PyCustomOpDef const* opdef) {
  PyCustomOpFactory(const PyCustomOpDef* opdef) {
    if (opdef == nullptr)
      throw std::runtime_error("Python definition is empty.");
    opdef_ = opdef;
  }

  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    auto kernel = new PyCustomOpKernel(api);
    kernel->set_opdef_id(opdef_ == nullptr ? uint64_t(0) : opdef_->obj_id);
    return kernel;
    return new PyCustomOpKernel(api, opdef_->obj_id);
  };

  const char* GetName() const {
    return opdef_ == nullptr ? "Unknown" : opdef_->op_type.c_str();
    return opdef_->op_type.c_str();
  };

  size_t GetInputTypeCount() const {
    return opdef_ == nullptr ? 1 : opdef_->input_types.size();
    return opdef_->input_types.size();
  };

  ONNXTensorElementDataType GetInputType(size_t idx) const {
    return opdef_ == nullptr ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT : static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
    return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
  };

  size_t GetOutputTypeCount() const {
    return opdef_ == nullptr ? 1 : opdef_->output_types.size();
    return opdef_->output_types.size();
  };

  ONNXTensorElementDataType GetOutputType(size_t idx) const {
    return opdef_ == nullptr ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT : static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
  };
    return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
  }

  PyCustomOpDef const* opdef_ = nullptr;
  const PyCustomOpDef* opdef_;
};

std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list();
const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count);
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <iostream>
#include <sstream>

template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
  ss << t;
}

template <typename T, typename... Args>
void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
  MakeStringInternal(ss, t);
  MakeStringInternal(ss, args...);
}

template <typename... Args>
std::string MakeString(const Args&... args) {
  std::ostringstream ss;
  MakeStringInternal(ss, args...);
  return std::string(ss.str());
}
@@ -0,0 +1 @@
pytest
@@ -1,3 +1,4 @@
numpy
onnx
onnxconverter_common
onnxruntime
@@ -1,8 +1,11 @@
import unittest
import os
import numpy as np
from onnx import helper, onnx_pb as onnx_proto
from numpy.testing import assert_almost_equal
from onnx import helper, onnx_pb as onnx_proto, load
import onnxruntime as _ort
from ortcustomops import (
    onnx_op,
    onnx_op, PyCustomOpDef,
    get_library_path as _get_library_path)


@@ -13,34 +16,142 @@ def _create_test_model():
                              ['identity1'], ['reversed'],
                              domain='ai.onnx.contrib')]

    input0 = helper.make_tensor_value_info('input_1', onnx_proto.TensorProto.FLOAT, [None, 2])
    output0 = helper.make_tensor_value_info('reversed', onnx_proto.TensorProto.FLOAT, [None, 2])
    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.FLOAT, [None, 2])
    output0 = helper.make_tensor_value_info(
        'reversed', onnx_proto.TensorProto.FLOAT, [None, 2])

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid('ai.onnx.contrib', 1)])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid('ai.onnx.contrib', 1)])
    return model


@onnx_op(op_type="ReverseMatrix")
def reverse_matrix(x):
    # the user custom op implementation here:
    return np.flip(x, axis=0)
def _create_test_model_double(domain):
    nodes = []
    nodes[0:] = [helper.make_node('Identity', ['input_1'], ['identity1'])]
    nodes[1:] = [helper.make_node('AddEpsilon',
                                  ['identity1'], ['customout'],
                                  domain=domain)]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.DOUBLE, [None, None])
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.DOUBLE, [None, None])

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model


# TODO: refactor the following code into pytest cases, right now, the script is more friendly for debugging.
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
def _create_test_model_2outputs(domain):
    nodes = [
        helper.make_node('Identity', ['x'], ['identity1']),
        helper.make_node(
            'NegPos', ['identity1'], ['neg', 'pos'], domain=domain)
    ]

sess0 = _ort.InferenceSession('./test/data/custom_op_test.onnx', so)
    input0 = helper.make_tensor_value_info(
        'x', onnx_proto.TensorProto.FLOAT, [])
    output1 = helper.make_tensor_value_info(
        'neg', onnx_proto.TensorProto.FLOAT, [])
    output2 = helper.make_tensor_value_info(
        'pos', onnx_proto.TensorProto.FLOAT, [])

res = sess0.run(None, {
    'input_1': np.random.rand(3, 5).astype(np.float32), 'input_2': np.random.rand(3, 5).astype(np.float32)})
    graph = helper.make_graph(nodes, 'test0', [input0], [output1, output2])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model

print(res[0])

sess = _ort.InferenceSession(_create_test_model().SerializeToString(), so)
class TestPythonOp(unittest.TestCase):

txout = sess.run(None, {
    'input_1': np.array([1, 2, 3, 4, 5, 6]).astype(np.float32).reshape([3, 2])})
    @classmethod
    def setUpClass(cls):

print(txout[0])
        @onnx_op(op_type="ReverseMatrix")
        def reverse_matrix(x):
            # The user custom op implementation here.
            return np.flip(x, axis=0).astype(np.float32)

        @onnx_op(op_type="AddEpsilon",
                 inputs=[PyCustomOpDef.dt_double],
                 outputs=[PyCustomOpDef.dt_double])
        def add_epsilon(x):
            # The user custom op implementation here.
            return x + 1e-3

        @onnx_op(op_type="NegPos",
                 inputs=[PyCustomOpDef.dt_float],
                 outputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_float])
        def negpos(x):
            neg = x.copy()
            pos = x.copy()
            neg[x > 0] = 0
            pos[x < 0] = 0
            return neg, pos

    def test_python_operator(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model()
        self.assertIn('op_type: "ReverseMatrix"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array(
            [1, 2, 3, 4, 5, 6]).astype(np.float32).reshape([3, 2])
        txout = sess.run(None, {'input_1': input_1})
        assert_almost_equal(txout[0], np.array([[5., 6.], [3., 4.], [1., 2.]]))

    def test_add_epsilon_python(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_double('ai.onnx.contrib')
        self.assertIn('op_type: "AddEpsilon"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array([[0., 1., 1.5], [7., 8., -5.5]])
        txout = sess.run(None, {'input_1': input_1})
        diff = txout[0] - input_1 - 1e-3
        assert_almost_equal(diff, np.zeros(diff.shape))

    def test_python_negpos(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_2outputs('ai.onnx.contrib')
        self.assertIn('op_type: "NegPos"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
        neg, pos = sess.run(None, {'x': x})
        diff = x - (neg + pos)
        assert_almost_equal(diff, np.zeros(diff.shape))

    def test_cc_negpos(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_2outputs('test.customop')
        self.assertIn('op_type: "NegPos"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
        neg, pos = sess.run(None, {'x': x})
        diff = x - (neg + pos)
        assert_almost_equal(diff, np.zeros(diff.shape))

    def test_cc_operator(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        this = os.path.dirname(__file__)
        filename = os.path.join(this, 'data', 'custom_op_test.onnx')
        onnx_content = load(filename)
        self.assertIn('op_type: "CustomOpOne"', str(onnx_content))
        sess0 = _ort.InferenceSession(filename, so)

        res = sess0.run(None, {
            'input_1': np.random.rand(3, 5).astype(np.float32),
            'input_2': np.random.rand(3, 5).astype(np.float32)})

        self.assertEqual(res[0].shape, (3, 5))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,168 @@
# coding: utf-8
import unittest
import numpy as np
from onnx import helper, onnx_pb as onnx_proto
import onnxruntime as _ort
from ortcustomops import (
    onnx_op, PyCustomOpDef,
    get_library_path as _get_library_path)


def _create_test_model_string_upper(domain):
    nodes = []
    nodes[0:] = [helper.make_node('Identity', ['input_1'], ['identity1'])]
    nodes[1:] = [helper.make_node('StringUpper',
                                  ['identity1'], ['customout'],
                                  domain=domain)]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None, 1])
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, 1])

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model


def _create_test_model_string_join(domain):
    nodes = []
    nodes.append(
        helper.make_node('Identity', ['text'], ['identity1']))
    nodes.append(
        helper.make_node('Identity', ['sep'], ['identity2']))
    nodes.append(
        helper.make_node(
            'StringJoin', ['identity1', 'identity2'],
            ['customout'], domain=domain))

    input0 = helper.make_tensor_value_info(
        'text', onnx_proto.TensorProto.STRING, [None, None])
    input1 = helper.make_tensor_value_info(
        'sep', onnx_proto.TensorProto.STRING, [1])
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, 1])

    graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model


class TestPythonOpString(unittest.TestCase):

    @classmethod
    def setUpClass(cls):

        @onnx_op(op_type="StringUpper",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_upper(x):
            # The user custom op implementation here.
            return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)

        @onnx_op(op_type="StringJoin",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_join(x, sep):
            # The user custom op implementation here.
            if sep.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'sep'.".format(sep.shape))
            sp = sep[0]
            return np.array([sp.join(row) for row in x])

    def test_check_types(self):
        def_list = set(dir(PyCustomOpDef))
        type_list = [
            # 'dt_bfloat16',
            'dt_bool',
            'dt_complex128',
            'dt_complex64',
            'dt_double',
            'dt_float',
            'dt_float16',
            'dt_int16',
            'dt_int32',
            'dt_int64',
            'dt_int8',
            'dt_string',
            'dt_uint16',
            'dt_uint32',
            'dt_uint64',
            'dt_uint8']
        for t in type_list:
            self.assertIn(t, def_list)

    def test_string_upper_cc(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_upper('test.customop')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_cc_accent(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_upper('test.customop')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABCé"]]).tolist())

    def test_string_upper_python(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_upper('ai.onnx.contrib')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_python_accent(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_upper('ai.onnx.contrib')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(),
                         np.array([["ABCé".upper()]]).tolist())

    def test_string_join_python(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_join('ai.onnx.contrib')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        self.assertEqual(text.shape, (2, 3))
        sep = np.array([";"])
        txout = sess.run(None, {'text': text, 'sep': sep})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())

    def test_string_join_cc(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_join('test.customop')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        sep = np.array([";"])
        txout = sess.run(None, {'text': text, 'sep': sep})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())


if __name__ == "__main__":
    unittest.main()
@@ -1,3 +1,5 @@
import unittest
from onnx import load
import torch
import onnxruntime as _ort
import io
@@ -9,41 +11,55 @@ from ortcustomops import (
    get_library_path as _get_library_path)


@onnx_op(op_type="Inverse")
def inverse(x):
    # the user custom op implementation here:
    return numpy.linalg.inv(x)


def my_inverse(g, self):
    return g.op("ai.onnx.contrib::Inverse", self)


# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)


class CustomInverse(torch.nn.Module):
    def forward(self, x):
        return torch.inverse(x) + x


x = torch.randn(3, 3)
class TestPyTorchCustomOp(unittest.TestCase):

# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)
    @classmethod
    def setUpClass(cls):

model = CustomInverse()
pt_outputs = model(x)
        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
    def test_custom_pythonop_pytorch(self):

# Run the exported model with ONNX Runtime
ort_sess = _ort.InferenceSession(f.getvalue(), so)
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)
        # register_custom_op_symbolic(
        #     '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)

# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
        x = torch.randn(3, 3)

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(CustomInverse(), (x,), f)
        onnx_model = load(io.BytesIO(f.getvalue()))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        model = CustomInverse()
        pt_outputs = model(x)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        # Run the exported model with ONNX Runtime
        ort_sess = _ort.InferenceSession(f.getvalue(), so)
        ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy())
                          for i, input in enumerate((x,)))
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs.cpu().numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)


if __name__ == "__main__":
    unittest.main()