onnxruntime-extensions/pyop/pykernel.h

134 строки
4.2 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <vector>
#include <map>
#include <memory>
// Plain description of a custom operator: its type name, an object id, the
// element-type codes of its inputs/outputs and a set of named integer attributes.
// PyCustomOpFactory (below in this header) reads obj_id, attrs, input_types and
// output_types from an instance of this struct.
struct PyCustomOpDef {
  // Operator type name.
  std::string op_type;
  // Object identifier forwarded to PyCustomOpKernel; 0 unless set by the caller.
  uint64_t obj_id{0};
  // Per-input / per-output type codes. PyCustomOpFactory casts each entry to
  // ONNXTensorElementDataType.
  std::vector<int> input_types;
  std::vector<int> output_types;
  // Attribute name -> integer code.
  // NOTE(review): the meaning of the integer is not visible in this header - confirm.
  std::map<std::string, int> attrs;

  // Declared here, defined elsewhere.
  static void AddOp(const PyCustomOpDef* cod);

  // The constants below are deliberately declared without an in-class
  // initializer to avoid gcc whole-archive; their definitions live outside this header.
  // NOTE(review): presumably they mirror the ONNX tensor element type codes - confirm.
  static const int undefined;
  static const int dt_float;
  static const int dt_uint8;
  static const int dt_int8;
  static const int dt_uint16;
  static const int dt_int16;
  static const int dt_int32;
  static const int dt_int64;
  static const int dt_string;
  static const int dt_bool;
  static const int dt_float16;
  static const int dt_double;
  static const int dt_uint32;
  static const int dt_uint64;
  static const int dt_complex64;
  static const int dt_complex128;
  static const int dt_bfloat16;
};
// Kernel object for a Python-defined custom op. Instances are created by the
// CreateKernel callback of PyCustomOpFactory and destroyed by its KernelDestroy
// callback; the member functions are only declared here and defined elsewhere.
struct PyCustomOpKernel {
// Takes the ORT API table, the kernel info, the object id taken from
// PyCustomOpDef::obj_id and the attribute map taken from PyCustomOpDef::attrs.
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info, uint64_t id, const std::map<std::string, int>& attrs);
// Called from PyCustomOpFactory's KernelCompute callback with the ORT kernel context.
void Compute(OrtKernelContext* context);
private:
// Reference to the ORT API table; the referenced object must outlive this kernel.
const OrtApi& api_;
OrtW::CustomOpApi ort_;
uint64_t obj_id_;
// NOTE(review): presumably the attribute values read from the kernel info,
// keyed by attribute name - confirm against the constructor definition.
std::map<std::string, std::string> attrs_values_;
};
struct PyCustomOpFactory : public OrtCustomOp {
PyCustomOpFactory() = default;
PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
if (opdef == nullptr)
throw std::runtime_error("Python definition is empty.");
opdef_ = opdef;
op_domain_ = domain;
op_type_ = op;
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
void* p = nullptr;
OCOS_API_IMPL_BEGIN
auto self = static_cast<const PyCustomOpFactory*>(this_);
auto kernel = std::make_unique<PyCustomOpKernel>(*api, *info, self->opdef_->obj_id, self->opdef_->attrs).release();
p = reinterpret_cast<void*>(kernel);
OCOS_API_IMPL_END
return p;
};
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->op_type_.c_str();
};
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
return "CPUExecutionProvider";
};
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->opdef_->input_types.size();
};
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return static_cast<ONNXTensorElementDataType>(self->opdef_->input_types[index]);
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->opdef_->output_types.size();
};
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return static_cast<ONNXTensorElementDataType>(self->opdef_->output_types[index]);
};
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) noexcept {
OCOS_API_IMPL_BEGIN
static_cast<PyCustomOpKernel*>(op_kernel)->Compute(context);
OCOS_API_IMPL_END
};
OrtCustomOp::KernelDestroy = [](void* op_kernel) noexcept {
std::unique_ptr<PyCustomOpKernel>(reinterpret_cast<PyCustomOpKernel*>(op_kernel)).reset();
};
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
};
}
const PyCustomOpDef* opdef_ = nullptr;
std::string op_type_;
std::string op_domain_;
};
// Declared here, defined elsewhere; `enable` defaults to true.
// NOTE(review): presumably switches Python custom-op support on or off; the
// meaning of the returned bool is not visible in this header - confirm.
bool EnablePyCustomOps(bool enable = true);
#if defined(ENABLE_C_API)
// Only declared when ENABLE_C_API is defined.
// NOTE(review): presumably registers C-API related functions on the pybind11
// module `m` - confirm against the definition. pybind11 is not included by this
// header directly, so it has to be made visible by ocos.h or by the including file.
void AddGlobalMethodsCApi(pybind11::module& m);
#endif