onnxruntime-extensions/include/custom_op/custom_op_lite.h

1036 строки
42 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include "tensor_api.h"
#include "ort_c_to_cpp.h"
#include "onnxruntime_f16.h"
namespace Ort {
namespace Custom {
class OrtKernelContextStorage : public ITensorStorage {
public:
OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
const_value_ = api_.KernelContext_GetInput(&ctx, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
}
}
const std::vector<int64_t>& Shape() const override {
if (!IsInitialized())
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
return *shape_;
}
virtual bool IsInitialized() const override {
return shape_.has_value();
}
const void* DataRaw() const override {
return api_.GetTensorRawData(const_value_);
}
void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
if (!const_value_) {
const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
shape_ = shape;
}
return api_.GetTensorMutableRawData(const_cast<OrtValue*>(const_value_));
}
void* Release() override {
ORTX_CXX_API_THROW("Can't release the tensor buffer with ORT graph mode.", ORT_RUNTIME_EXCEPTION);
}
private:
const OrtW::CustomOpApi& api_;
OrtKernelContext& ctx_;
size_t indice_;
const OrtValue* const_value_{}; // for input
std::optional<std::vector<int64_t>> shape_;
};
/// Returns the memory-info name of the `indice`-th input tensor.
///
/// Falls back to "Cpu" when `is_input` is false, when no memory info is reported, or when the
/// memory info has no name. Failures of the underlying ORT calls are surfaced through
/// `custom_op_api.ThrowOnError`.
static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api,
                                OrtKernelContext& ctx,
                                size_t indice,
                                bool is_input) {
  // Outputs are never queried.
  if (!is_input) {
    return std::string("Cpu");
  }

  const OrtValue* input_value = custom_op_api.KernelContext_GetInput(&ctx, indice);

  const OrtMemoryInfo* memory_info = nullptr;
  custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(input_value, &memory_info));
  if (memory_info == nullptr) {
    return std::string("Cpu");
  }

  const char* memory_name = nullptr;
  custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(memory_info, &memory_name));
  if (memory_name == nullptr) {
    return std::string("Cpu");
  }
  return std::string(memory_name);
}
/// Tensor<T> whose storage is an OrtKernelContextStorage bound to one input or output of the
/// kernel context. Also records the memory-info name reported by get_mem_type().
template <typename T>
class OrtTensor : public Tensor<T> {
 public:
  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
            OrtKernelContext& ctx,
            size_t indice,
            bool is_input)
      : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)),
        mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}

  /// True when the recorded memory-info name is exactly "Cpu".
  bool IsCpuTensor() const {
    return mem_type_.compare("Cpu") == 0;
  }

 private:
  std::string mem_type_ = "Cpu";
};
/// String tensor storage bound to one input or output of an ORT kernel context.
///
/// For inputs, every string of the tensor is copied into an owned std::string in the
/// constructor. For outputs, SetStringOutput() asks the kernel context for the output and
/// fills it.
class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
 public:
  using strings = std::vector<std::string>;

  /// @param custom_op_api  wrapper around the ORT C API (kept by reference)
  /// @param ctx            kernel context of the current Compute call (kept by reference)
  /// @param indice         index of the input or output this storage is bound to
  /// @param is_input       true to read an input eagerly, false for an output
  /// Raises "invalid indice" when an input index is not below the input count.
  OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api,
                         OrtKernelContext& ctx,
                         size_t indice,
                         bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
    if (!is_input) {
      return;
    }

    auto input_count = api_.KernelContext_GetInputCount(&ctx_);
    if (indice >= input_count) {
      ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
    }

    auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
    auto* info = api_.GetTensorTypeAndShape(const_value);
    shape_ = api_.GetTensorShape(info);
    api_.ReleaseTensorTypeAndShapeInfo(info);

    size_t num_chars = 0;
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
    // One spare byte keeps the buffer non-empty even when the tensor holds no characters.
    std::vector<char> chars(num_chars + 1, '\0');

    // Element count is the product of all dims; an empty shape is a string scalar (1 element).
    int64_t num_strings = 1;
    for (const int64_t dim : *shape_) {
      num_strings *= dim;
    }

    std::vector<size_t> offsets(static_cast<size_t>(num_strings));
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
                                                                                 chars.data(),
                                                                                 num_chars,
                                                                                 offsets.data(),
                                                                                 offsets.size()));

    // Build each string from [offset_i, offset_{i+1}) with an explicit length. This keeps
    // embedded '\0' characters intact instead of cutting the string at the first NUL.
    const size_t count = offsets.size();
    input_strings_.resize(count);
    for (size_t i = 0; i < count; ++i) {
      const size_t begin = offsets[i];
      const size_t end = (i + 1 < count) ? offsets[i + 1] : num_chars;
      input_strings_[i].assign(chars.data() + begin, end - begin);
    }
  }

  /// Returns the tensor shape; raises "Tensor not initialized" if no shape was recorded.
  const std::vector<int64_t>& Shape() const override {
    if (!IsInitialized())
      ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
    return *shape_;
  }

  /// Returns the character data of the single held string; raises unless exactly one string is held.
  virtual const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }

  /// True once a shape has been recorded (i.e. an input was read in the constructor).
  virtual bool IsInitialized() const override {
    return shape_.has_value();
  }

  /// Requests the output with shape `dims` from the kernel context and fills it with `ss`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
    std::vector<const char*> raw;
    raw.reserve(ss.size());
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
  }

  /// Same as above for callers that already hold C-string pointers.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
    auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
  }

  /// Strings copied from the input tensor (empty for outputs).
  const strings& Data() const override {
    return input_strings_;
  }

 private:
  const OrtW::CustomOpApi& api_;
  OrtKernelContext& ctx_;
  size_t indice_;
  std::vector<std::string> input_strings_;
  std::optional<std::vector<int64_t>> shape_;
};
/// Read-only string tensor storage exposing the input strings as std::string_view.
///
/// The constructor copies the packed character data of the input into `chars_` and creates one
/// view per tensor element into that buffer. Writing outputs is not supported.
class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
 public:
  using strings = std::vector<std::string_view>;

  /// @param custom_op_api  wrapper around the ORT C API (kept by reference)
  /// @param ctx            kernel context of the current Compute call (kept by reference)
  /// @param indice         index of the input this storage is bound to
  /// @param is_input       true to read the input eagerly; false leaves the storage empty
  /// Raises "invalid indice" when an input index is not below the input count.
  OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api,
                             OrtKernelContext& ctx,
                             size_t indice,
                             bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
    if (!is_input) {
      return;
    }

    auto input_count = api_.KernelContext_GetInputCount(&ctx_);
    if (indice >= input_count) {
      ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
    }

    auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
    auto* info = api_.GetTensorTypeAndShape(const_value);
    shape_ = api_.GetTensorShape(info);
    api_.ReleaseTensorTypeAndShapeInfo(info);

    size_t num_chars = 0;
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
    chars_.resize(num_chars + 1, '\0');

    // Element count is the product of ALL dims (1 for a scalar). Using only the first dim
    // would under-size the offsets buffer for tensors of rank >= 2.
    size_t num_strings = 1;
    for (const int64_t dim : *shape_) {
      num_strings *= static_cast<size_t>(dim);
    }
    if (num_strings == 0) {
      return;
    }

    std::vector<size_t> offsets(num_strings);
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
                                                                                 chars_.data(),
                                                                                 num_chars,
                                                                                 offsets.data(),
                                                                                 offsets.size()));
    // Sentinel so that every string, including the last, ends at offsets[i + 1].
    offsets.push_back(num_chars);

    input_string_views_.reserve(num_strings);
    for (size_t i = 0; i < num_strings; ++i) {
      input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
    }
  }

  /// Returns the tensor shape; raises "Tensor not initialized" if no shape was recorded.
  const std::vector<int64_t>& Shape() const override {
    if (!IsInitialized())
      ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
    return *shape_;
  }

  /// Returns the data pointer of the single held view; raises unless exactly one view is held.
  virtual const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }

  /// True once a shape has been recorded (i.e. an input was read in the constructor).
  virtual bool IsInitialized() const override {
    return shape_.has_value();
  }

  /// Not supported: always raises.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
    ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
  }

  /// Not supported: always raises.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
    ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
  }

  /// Views into `chars_`, one per tensor element; valid for the lifetime of this storage.
  const strings& Data() const override {
    return input_string_views_;
  }

 private:
  const OrtW::CustomOpApi& api_;
  OrtKernelContext& ctx_;
  size_t indice_;
  std::vector<char> chars_;                           // for input: packed character data
  std::vector<std::string_view> input_string_views_;  // for input: views into chars_
  std::optional<std::vector<int64_t>> shape_;
};
// Explicit specialization so the template machinery can name OrtTensor<std::string>;
// it plugs OrtStringTensorStorage in as the backing storage.
template <>
class OrtTensor<std::string> : public Tensor<std::string> {
 public:
  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
            OrtKernelContext& ctx,
            size_t indice,
            bool is_input)
      : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)),
        mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
  }

  // True when the recorded memory-info name is exactly "Cpu".
  bool IsCpuTensor() const {
    return mem_type_.compare("Cpu") == 0;
  }

 private:
  std::string mem_type_ = "Cpu";
};
// Explicit specialization for string-view tensors; the backing storage is
// OrtStringViewTensorStorage.
template <>
class OrtTensor<std::string_view> : public Tensor<std::string_view> {
 public:
  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
            OrtKernelContext& ctx,
            size_t indice,
            bool is_input)
      : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)),
        mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
  }

  // True when the recorded memory-info name is exactly "Cpu".
  bool IsCpuTensor() const {
    return mem_type_.compare("Cpu") == 0;
  }

 private:
  std::string mem_type_ = "Cpu";
};
// Owning pointer to any kernel argument (Custom::Arg) and a list of such pointers.
using TensorPtr = std::unique_ptr<Custom::Arg>;
using TensorPtrs = std::vector<TensorPtr>;
// Owning pointer to a type-erased tensor (Custom::TensorBase) and a list of such pointers.
using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
using TensorBasePtrs = std::vector<TensorBasePtr>;
// Represent variadic input or output.
//
// As an input, the constructor wraps EVERY input of the kernel context (starting at index 0)
// in an OrtTensor of the matching element type. As an output, the list starts empty and is
// filled by the compute function through AllocateOutput() / AllocateStringTensor().
struct Variadic : public Arg {
  // custom_op_api / ctx are kept by reference. `indice` is the argument position this
  // Variadic was created for; `is_input` selects input or output mode.
  // Raises for an input whose element type is not handled by the switch below.
  Variadic(const OrtW::CustomOpApi& custom_op_api,
           OrtKernelContext& ctx,
           size_t indice,
           bool is_input)
      : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
#if ORT_API_VERSION < 14
    ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
#endif
    if (is_input) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
        auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
        auto* info = api_.GetTensorTypeAndShape(const_value);
        auto type = api_.GetTensorElementType(info);
        api_.ReleaseTensorTypeAndShapeInfo(info);
        TensorBasePtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true);
            break;
          default:
            ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
            break;
        }
        // Move ownership into the list; unlike release() + emplace_back(raw pointer), the
        // tensor is not leaked if growing the vector throws.
        tensors_.push_back(std::move(tensor));
      }  // for
    } else {
      // a Variadic used for output is populated by the Compute so leave tensors_ empty here
    }
  }

  // Creates the `ith_output` tensor with the given shape, stores it in this Variadic and
  // returns the pointer produced by Allocate(shape).
  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    auto tensor = std::make_unique<OrtTensor<T>>(api_, ctx_, ith_output, false);
    auto raw_output = tensor->Allocate(shape);
    // The heap object does not move, so raw_output stays valid after the ownership transfer.
    tensors_.emplace_back(std::move(tensor));
    return raw_output;
  }

  // Creates the `ith_output` string tensor, stores it in this Variadic and returns a reference
  // to it; the reference stays valid for as long as this Variadic owns the tensor.
  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    auto tensor = std::make_unique<OrtTensor<std::string>>(api_, ctx_, ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(std::move(tensor));
    return output;
  }

  // Number of tensors currently held.
  size_t Size() const {
    return tensors_.size();
  }

  // Bounds-checked access (vector::at) to the `indice`-th held tensor.
  const TensorBasePtr& operator[](size_t indice) const {
    return tensors_.at(indice);
  }

 private:
  const OrtW::CustomOpApi& api_;
  OrtKernelContext& ctx_;
  size_t indice_;
  std::string mem_type_ = "Cpu";
  TensorBasePtrs tensors_;
};
#if ORT_API_VERSION >= 17
class OrtGraphKernelContext : public KernelContext {
public:
OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
OrtMemoryInfo* info;
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator_));
api_.ReleaseMemoryInfo(info);
}
virtual ~OrtGraphKernelContext() {
if (allocator_) {
api_.ReleaseAllocator(allocator_);
}
}
void* AllocScratchBuffer(size_t size) override {
return allocator_->Alloc(allocator_, size);
}
void FreeScratchBuffer(void* p) override {
if (p) {
allocator_->Free(allocator_, p);
}
}
private:
const OrtApi& api_;
OrtAllocator* allocator_;
};
#endif
#ifdef USE_CUDA
// Resource identifiers passed to OrtApi::KernelContext_GetResource by
// OrtGraphCudaKernelContext. Values start at 10000 and increase by one per entry.
// NOTE(review): the numeric values presumably have to match the CUDA execution provider's
// own resource enum - confirm against the onnxruntime headers before reordering.
enum CudaResource {
// In this file the value fetched with this id is stored and returned as the CUDA stream.
cuda_handle_t = 10000,
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t,
};
#if ORT_API_VERSION >= 17
class OrtGraphCudaKernelContext : public CUDAKernelContext {
public:
static const int cuda_resource_ver = 1;
OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
OrtStatusPtr result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
if (result || !cuda_stream_) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
}
result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
if (result || !cublas_) {
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
}
void* resource = nullptr;
result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
if (result) {
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
}
memcpy(&device_id_, &resource, sizeof(int));
OrtMemoryInfo* info;
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
api_.ReleaseMemoryInfo(info);
OrtMemoryInfo* cuda_mem_info;
OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
api_.ReleaseMemoryInfo(cuda_mem_info);
}
virtual ~OrtGraphCudaKernelContext() {
if (cpu_allocator_) {
api_.ReleaseAllocator(cpu_allocator_);
}
if (cuda_allocator_) {
api_.ReleaseAllocator(cuda_allocator_);
}
}
void* AllocScratchBuffer(size_t size) override {
return cpu_allocator_->Alloc(cpu_allocator_, size);
}
void FreeScratchBuffer(void* p) override {
if (p) {
cpu_allocator_->Free(cpu_allocator_, p);
}
}
void* AllocCudaScratchBuffer(size_t size) override {
return cuda_allocator_->Alloc(cuda_allocator_, size);
}
void FreeCudaScratchBuffer(void* p) override {
if (p) {
cuda_allocator_->Free(cuda_allocator_, p);
}
}
void* GetCudaStream() const override {
return cuda_stream_;
}
void* GetCublasHandle() const override {
return cublas_;
}
int GetCudaDeviceId() const override {
return device_id_;
}
private:
const OrtApi& api_;
OrtAllocator* cpu_allocator_;
OrtAllocator* cuda_allocator_;
void* cuda_stream_ = {};
void* cublas_ = {};
int device_id_ = 0;
};
#endif
#endif
// using mf16_t = uint16_t;
// Base of the "lite" custom-op wrappers. It derives from the C struct OrtCustomOp and fills
// in its callback pointers, and it provides two families of static template helpers that are
// driven by the parameter types of the user's compute function:
//  - CreateTuple<ith_input, ith_output, Ts...>: builds, at Compute time, the tuple of
//    arguments that is passed to the compute function.
//  - ParseArgs<Ts...>: records, at registration time, the ONNX element type of every input
//    and output.
struct OrtLiteCustomOp : public OrtCustomOp {
// CreateTuple: each overload handles the first type T of the pack and recurses on the rest.
// `ith_input` / `ith_output` are the indices of the next input / output to bind. `tensors`
// owns the helper objects whose pointers or references are placed into the tuple, so it has
// to outlive the tuple.
// Recursion end: no types left, empty tuple.
template <size_t ith_input, size_t ith_output, typename... Ts>
static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
return std::make_tuple();
}
// T == OrtKernelContext*: pass the raw kernel context through; no index is consumed.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#if ORT_API_VERSION >= 17
// T == KernelContext*: create an OrtGraphKernelContext owned by `tensors` and pass a pointer
// to it; no index is consumed.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, KernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<OrtGraphKernelContext>(api->GetOrtApi(), *context));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#ifdef USE_CUDA
// T == CUDAKernelContext*: same as above with an OrtGraphCudaKernelContext.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, CUDAKernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<OrtGraphCudaKernelContext>(api->GetOrtApi(), *context));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#endif
#endif
#if ORT_API_VERSION >= 14
// T == const Variadic*: variadic INPUT passed by pointer; consumes one input index.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
// T == const Variadic&: variadic INPUT passed by reference; consumes one input index.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
// T == Variadic*: variadic OUTPUT passed by pointer; consumes one output index.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
// T == Variadic&: variadic OUTPUT passed by reference; consumes one output index.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#endif
// "tensor_tuple.inc" is included once per element type, with `data_type_def` set to that type.
// NOTE(review): its contents are not visible here; presumably it defines the CreateTuple
// overloads for tensor / span / scalar arguments of `data_type_def` - confirm in the .inc file.
#undef data_type_def
#define data_type_def bool
#include "tensor_tuple.inc"
#if ORT_API_VERSION >= 16
#undef data_type_def
#define data_type_def BFloat16
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def MFloat16
#include "tensor_tuple.inc"
#endif
#undef data_type_def
#define data_type_def float
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def double
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def int8_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def int16_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def int32_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def int64_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def uint8_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def uint16_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def uint32_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def uint64_t
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def std::string
#include "tensor_tuple.inc"
#undef data_type_def
#define data_type_def std::string_view
#include "tensor_tuple.inc"
#undef data_type_def
// ParseArgs: each overload handles the first type T of the pack, appends the matching ONNX
// element type to `input_types` or `output_types` (or nothing for context arguments), and
// recurses on the rest. The `0 <= sizeof...(Ts)` term in the conditions is always true.
// Recursion end: no types left.
template <typename... Ts>
static typename std::enable_if<0 == sizeof...(Ts)>::type
ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
}
// OrtKernelContext* is neither an input nor an output: nothing is recorded.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#if ORT_API_VERSION >= 17
// KernelContext* is neither an input nor an output: nothing is recorded.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, KernelContext*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#ifdef USE_CUDA
// CUDAKernelContext* is neither an input nor an output: nothing is recorded.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, CUDAKernelContext*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif
#endif
#if ORT_API_VERSION >= 14
// Variadic input (by reference): recorded as a single UNDEFINED entry. Raises if an input
// type has already been recorded before it.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
// Variadic input (by pointer): same handling as the by-reference form.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
// Variadic output (by reference): recorded as a single UNDEFINED entry. Raises if an output
// type has already been recorded before it.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
// Variadic output (by pointer): same handling as the by-reference form.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
#endif
// PARSE_INPUT_BASE(pack_type, onnx_type): generates two ParseArgs overloads that record
// `onnx_type` as an input, one for `pack_type` and one for `std::optional<pack_type>`.
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
input_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
input_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
}
// PARSE_INPUT(data_type, onnx_type): input overloads for const Tensor pointer/reference,
// const Span pointer/reference, and a plain `data_type` value.
#define PARSE_INPUT(data_type, onnx_type) \
PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
PARSE_INPUT_BASE(data_type, onnx_type)
// PARSE_OUTPUT(data_type, onnx_type): output overloads for non-const Tensor pointer,
// Tensor reference and std::optional<Tensor pointer>; each records `onnx_type` as an output.
#define PARSE_OUTPUT(data_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
}
// PARSE_ARGS(data_type, onnx_type): both the input and the output overload sets.
#define PARSE_ARGS(data_type, onnx_type) \
PARSE_INPUT(data_type, onnx_type) \
PARSE_OUTPUT(data_type, onnx_type)
// Instantiate the ParseArgs overloads for every supported element type.
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
#if ORT_API_VERSION >= 16
PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
#endif
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
// Stores the op name and execution provider name and installs the OrtCustomOp callbacks that
// do not depend on the compute signature. The kernel callbacks are not set here.
// NOTE(review): both strings are built directly from the pointers, so callers presumably must
// not pass nullptr - confirm at the call sites.
OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
// Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
memset(&this->version, 0, sizeof(OrtCustomOp));
// Advertise the smaller of the active API version and the version compiled against.
int act_ver = GetActiveOrtAPIVersion();
OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
// Input/output counts and types come from the vectors filled by ParseArgs.
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->input_types_.size();
};
OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->input_types_[indice];
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->output_types_.size();
};
OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->output_types_[indice];
};
#if ORT_API_VERSION >= 14
// The index argument is ignored: every input is reported as VARIADIC when the first recorded
// input type is UNDEFINED (the Variadic marker), otherwise as OPTIONAL.
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};
// Same rule for outputs, based on the first recorded output type.
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};
// Variadic inputs/outputs: minimum arity 1, homogeneity flag 0.
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
return 1;
};
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
return 0;
};
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
return 1;
};
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
return 0;
};
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
return OrtMemTypeDefault;
};
#else
// Before API version 14 every input and output is reported as OPTIONAL.
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
#endif
#if ORT_API_VERSION >= 18
// No in-place input/output pairs are declared.
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
return 0;
};
OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
#endif
}
// Op name and execution provider name returned by GetName / GetExecutionProviderType.
const std::string op_name_;
const std::string execution_provider_;
// ONNX element types of the inputs / outputs; filled by ParseArgs in the derived classes.
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;
};
/// Custom op whose computation is a free function `void(Args...)`.
///
/// The constructor derives the input/output element types from `Args...` through ParseArgs
/// and installs the three kernel callbacks (create, compute, destroy) on the OrtCustomOp base.
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  /// Per-kernel state handed to ORT as an opaque pointer.
  struct Kernel {
    ComputeFn compute_fn_{};
    std::string ep_{};
    std::unique_ptr<OrtW::CustomOpApi> api_;
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn)
      : OrtLiteCustomOp(op_name, execution_provider),
        compute_fn_(compute_fn) {
    ParseArgs<Args...>(input_types_, output_types_);

    // Creates the Kernel and copies the compute function and EP name out of the op.
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* /*info*/) {
      const auto* op = static_cast<const OrtLiteCustomFunc*>(this_);
      auto state = std::make_unique<Kernel>();
      state->compute_fn_ = op->compute_fn_;
      state->ep_ = op->execution_provider_;
      state->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
      return static_cast<void*>(state.release());
    };

    // Builds the argument tuple from the kernel context and invokes the compute function.
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto* state = static_cast<Kernel*>(op_kernel);
      auto* api = state->api_.get();
      const auto input_total = api->KernelContext_GetInputCount(context);
      const auto output_total = api->KernelContext_GetOutputCount(context);

      // `owned_args` keeps the objects referenced by the tuple alive until the call returns.
      std::vector<TensorPtr> owned_args;
      auto call_args = CreateTuple<0, 0, Args...>(api,
                                                  context,
                                                  owned_args,
                                                  input_total,
                                                  output_total,
                                                  state->ep_);
      std::apply([state](Args const&... t_args) { state->compute_fn_(t_args...); }, call_args);
    };

    // Destroys the Kernel created by CreateKernel.
    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete static_cast<Kernel*>(op_kernel);
    };
  }

  ComputeFn compute_fn_;
};
// Wrapper around a BaseKernel that exposes only attribute lookup with a
// fallback value. OrtLiteCustomStruct passes an instance to CustomOp
// constructors that do not take (const OrtApi&, const OrtKernelInfo&).
class OrtAttributeReader {
 public:
  // Constructs the wrapped BaseKernel from the given API and kernel info.
  OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info)
      : base_kernel_(ort_api, info) {}

  // Forwards to BaseKernel::TryToGetAttributeWithDefault and returns its result.
  // Declared noexcept, so an exception escaping the forwarded call would
  // terminate the program.
  template <class AttrT>
  AttrT TryToGetAttributeWithDefault(const char* name,
                                     const AttrT& default_value) const noexcept {
    return base_kernel_.TryToGetAttributeWithDefault(name, default_value);
  }

 private:
  BaseKernel base_kernel_;
};
// Adapter that exposes a user type `CustomOp` with a const `Compute(Args...)`
// member through the OrtCustomOp callback table.
//
// The argument list is deduced from &CustomOp::Compute; only const-qualified
// `void Compute(Args...) const` signatures match CustomComputeFn.
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...) const;

  using MyType = OrtLiteCustomStruct<CustomOp>;

  // State created by CreateKernel and passed back as the opaque `op_kernel`
  // pointer to KernelCompute / KernelDestroy.
  struct Kernel {
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
    std::unique_ptr<OrtW::CustomOpApi> api_;
  };

  // op_name / execution_provider are forwarded to the OrtLiteCustomOp base;
  // the callbacks are then wired up from CustomOp::Compute's signature.
  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider)
      : OrtLiteCustomOp(op_name, execution_provider) {
    init(&CustomOp::Compute);
  }

  // Installs the callbacks. The member-function-pointer argument is used only
  // to deduce Args...; its value is ignored.
  template <typename... Args>
  void init(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);

    // Build the argument tuple from the kernel context, then call Compute.
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto* state = reinterpret_cast<Kernel*>(op_kernel);
      // Passed by reference to CreateTuple; declared before the tuple so it
      // outlives the call below (presumably it backs the tuple's arguments).
      std::vector<TensorPtr> tensors;
      auto* api = state->api_.get();
      const auto num_inputs = api->KernelContext_GetInputCount(context);
      const auto num_outputs = api->KernelContext_GetOutputCount(context);
      auto call_args = CreateTuple<0, 0, Args...>(api,
                                                  context,
                                                  tensors,
                                                  num_inputs,
                                                  num_outputs,
                                                  state->ep_);
      std::apply(
          [state](Args const&... unpacked) { state->custom_op_->Compute(unpacked...); },
          call_args);
    };

    // Construct the user op first (directly from the API and kernel info when
    // it accepts them, otherwise from an OrtAttributeReader), then fill in the
    // remaining Kernel fields and release ownership to the caller as void*.
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_,
                                   const OrtApi* ort_api,
                                   const OrtKernelInfo* info) {
      auto state = std::make_unique<Kernel>();
      if constexpr (std::is_constructible_v<CustomOp, const OrtApi&, const OrtKernelInfo&>) {
        state->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
      } else {
        state->custom_op_ = std::make_unique<CustomOp>(OrtAttributeReader(*ort_api, *info));
      }
      const auto* op = static_cast<const MyType*>(this_);
      state->ep_ = op->execution_provider_;
      state->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
      return reinterpret_cast<void*>(state.release());
    };

    // Reclaim and destroy the Kernel produced by CreateKernel.
    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      std::unique_ptr<Kernel> reclaimed(reinterpret_cast<Kernel*>(op_kernel));
    };
  }
};
// Creates a heap-allocated OrtLiteCustomFunc wrapping `custom_compute_fn`.
// Ownership of the returned pointer is released to the caller; this function
// does not retain or delete it.
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...)) {
  auto owned_op = std::make_unique<OrtLiteCustomFunc<Args...>>(op_name,
                                                               execution_provider,
                                                               custom_compute_fn);
  return owned_op.release();
}
// Creates a heap-allocated OrtLiteCustomStruct<CustomOp> for a struct-based op.
// Ownership of the returned pointer is released to the caller; this function
// does not retain or delete it.
template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider) {
  auto owned_op = std::make_unique<OrtLiteCustomStruct<CustomOp>>(op_name,
                                                                  execution_provider);
  return owned_op.release();
}
} // namespace Custom
} // namespace Ort