Родитель
658aa28572
Коммит
ae416f6aa6
|
@ -18,8 +18,8 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
|
|||
const char c_OpDomain[] = "ai.onnx.contrib";
|
||||
|
||||
struct BaseKernel {
|
||||
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||
BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
|
||||
BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||
BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
|
||||
|
||||
bool HasAttribute(const char* name) const;
|
||||
|
||||
|
@ -37,7 +37,7 @@ struct BaseKernel {
|
|||
|
||||
protected:
|
||||
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
||||
const OrtApi& api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo* info_;
|
||||
};
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
|
||||
struct KernelGaussianBlur : BaseKernel {
|
||||
KernelGaussianBlur(OrtApi api) : BaseKernel(api) {
|
||||
KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -78,7 +78,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
|
|||
return "GaussianBlur";
|
||||
}
|
||||
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelGaussianBlur(api);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
|
||||
struct KernelImageReader : BaseKernel {
|
||||
KernelImageReader(OrtApi api) : BaseKernel(api) {
|
||||
KernelImageReader(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -49,7 +49,7 @@ struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageR
|
|||
return "ImageReader";
|
||||
}
|
||||
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelImageReader(api);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
|
||||
struct KernelInverse : BaseKernel {
|
||||
KernelInverse(OrtApi api) : BaseKernel(api) {
|
||||
KernelInverse(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -34,7 +34,7 @@ struct KernelInverse : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelInverse(api);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include "ocos.h"
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(OrtApi api) : BaseKernel(api) {
|
||||
KernelNegPos(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context){
|
||||
|
@ -40,7 +40,7 @@ struct KernelNegPos : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const{
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelNegPos(api);
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#include "segment_extraction.hpp"
|
||||
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
||||
|
@ -51,7 +51,7 @@ ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*inde
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
void* CustomOpSegmentExtraction::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpSegmentExtraction::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelSegmentExtraction(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentExtraction : BaseKernel {
|
||||
KernelSegmentExtraction(OrtApi api);
|
||||
KernelSegmentExtraction(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -16,6 +16,6 @@ struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction,
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@ -52,7 +52,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
|||
}
|
||||
}
|
||||
|
||||
KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) {
|
||||
KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelSegmentSum::Compute(OrtKernelContext* context) {
|
||||
|
@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) co
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpSegmentSum::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelSegmentSum(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentSum : BaseKernel {
|
||||
KernelSegmentSum(OrtApi api);
|
||||
KernelSegmentSum(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -16,6 +16,6 @@ struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentS
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
|
||||
KernelMaskedFill::KernelMaskedFill(OrtApi api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
|
||||
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||
|
@ -51,7 +51,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpMaskedFill::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelMaskedFill(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -8,14 +8,14 @@
|
|||
#include <unordered_map>
|
||||
|
||||
struct KernelMaskedFill : BaseKernel {
|
||||
KernelMaskedFill(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
};
|
||||
|
||||
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
#include "op_equal_impl.hpp"
|
||||
#include <string>
|
||||
|
||||
KernelStringEqual::KernelStringEqual(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringEqual::Compute(OrtKernelContext* context) {
|
||||
|
@ -24,7 +24,7 @@ ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) c
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
};
|
||||
|
||||
void* CustomOpStringEqual::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const{
|
||||
void* CustomOpStringEqual::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const{
|
||||
return new KernelStringEqual(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringEqual : BaseKernel {
|
||||
KernelStringEqual(OrtApi api);
|
||||
KernelStringEqual(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -16,6 +16,6 @@ struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelString
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "string_tensor.h"
|
||||
#include "op_ragged_tensor.hpp"
|
||||
|
||||
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
|
||||
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
||||
|
@ -53,7 +53,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*i
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
void* CustomOpRaggedTensorToSparse::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpRaggedTensorToSparse::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelRaggedTensorToSparse(api);
|
||||
};
|
||||
|
||||
|
@ -65,7 +65,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||
|
@ -84,7 +84,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
|
|||
return max_col;
|
||||
}
|
||||
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
|
||||
}
|
||||
|
||||
|
@ -134,7 +134,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
void* CustomOpRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelRaggedTensorToDense(api, info);
|
||||
};
|
||||
|
||||
|
@ -146,7 +146,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
|
@ -193,7 +193,7 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
void* CustomOpStringRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringRaggedTensorToDense(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include "ocos.h"
|
||||
|
||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||
KernelRaggedTensorToSparse(OrtApi api);
|
||||
KernelRaggedTensorToSparse(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -15,12 +15,12 @@ struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSp
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct CommonRaggedTensorToDense : BaseKernel {
|
||||
CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
||||
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
|
||||
protected:
|
||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
||||
|
@ -28,7 +28,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
|
|||
};
|
||||
|
||||
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -40,12 +40,12 @@ struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDen
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -54,6 +54,6 @@ struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRagge
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#include "re2/re2.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringRegexReplace(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringRegexReplace : BaseKernel {
|
||||
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -15,7 +15,7 @@ struct KernelStringRegexReplace : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||
|
@ -79,7 +79,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringRegexSplitWithOffsets(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -8,12 +8,12 @@
|
|||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
||||
KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringConcat::KernelStringConcat(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||
|
@ -37,7 +37,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, left_value, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringConcat::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringConcat(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringConcat : BaseKernel {
|
||||
KernelStringConcat(OrtApi api);
|
||||
KernelStringConcat(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <regex>
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringECMARegexReplace::KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
|
||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||
|
||||
|
@ -70,7 +70,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringECMARegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringECMARegexReplace(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringECMARegexReplace : BaseKernel {
|
||||
KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -16,7 +16,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "string_tensor.h"
|
||||
|
||||
|
||||
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringECMARegexSplitWithOffsets(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -9,14 +9,14 @@
|
|||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
||||
KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
bool ignore_case_;
|
||||
};
|
||||
|
||||
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "string_hash.hpp"
|
||||
|
||||
|
||||
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||
|
@ -43,7 +43,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpStringHash::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringHash::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringHash(api);
|
||||
};
|
||||
|
||||
|
@ -72,7 +72,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||
|
@ -106,7 +106,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpStringHashFast::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringHashFast::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringHashFast(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringHash : BaseKernel {
|
||||
KernelStringHash(OrtApi api);
|
||||
KernelStringHash(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
@ -21,12 +21,12 @@ struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHa
|
|||
};
|
||||
|
||||
struct KernelStringHashFast : BaseKernel {
|
||||
KernelStringHashFast(OrtApi api);
|
||||
KernelStringHashFast(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "string_join.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||
|
@ -86,7 +86,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, out, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringJoin::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringJoin(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringJoin : BaseKernel {
|
||||
KernelStringJoin(OrtApi api);
|
||||
KernelStringJoin(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringLength::KernelStringLength(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringLength::Compute(OrtKernelContext* context) {
|
||||
|
@ -27,7 +27,7 @@ void KernelStringLength::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpStringLength::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringLength::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringLength(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLength : BaseKernel {
|
||||
KernelStringLength(OrtApi api);
|
||||
KernelStringLength(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
KernelStringLower::KernelStringLower(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringLower::Compute(OrtKernelContext* context) {
|
||||
|
@ -25,7 +25,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, X, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringLower::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringLower::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringLower(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLower : BaseKernel {
|
||||
KernelStringLower(OrtApi api);
|
||||
KernelStringLower(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringMapping::KernelStringMapping(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api) {
|
||||
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
auto lines = SplitString(map, "\n", true);
|
||||
for (const auto& line: lines) {
|
||||
|
@ -41,7 +41,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, input_data, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringMapping::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringMapping(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -8,14 +8,14 @@
|
|||
#include <unordered_map>
|
||||
|
||||
struct KernelStringMapping : BaseKernel {
|
||||
KernelStringMapping(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
};
|
||||
|
||||
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "string_split.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||
|
@ -96,7 +96,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, words, out_text);
|
||||
}
|
||||
|
||||
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringSplit::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringSplit(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringSplit : BaseKernel {
|
||||
KernelStringSplit(OrtApi api);
|
||||
KernelStringSplit(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -100,7 +100,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
}
|
||||
}
|
||||
|
||||
KernelStringToVector::KernelStringToVector(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||
|
@ -132,7 +132,7 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpStringToVector::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelStringToVector(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ class StringToVectorImpl {
|
|||
};
|
||||
|
||||
struct KernelStringToVector : BaseKernel {
|
||||
KernelStringToVector(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -37,7 +37,7 @@ struct KernelStringToVector : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) {
|
||||
KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringUpper::Compute(OrtKernelContext* context) {
|
||||
|
@ -26,7 +26,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, X, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
void* CustomOpStringUpper::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringUpper(api);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringUpper : BaseKernel {
|
||||
KernelStringUpper(OrtApi api);
|
||||
KernelStringUpper(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -103,7 +103,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
}
|
||||
}
|
||||
|
||||
KernelVectorToString::KernelVectorToString(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||
|
||||
|
@ -124,7 +124,7 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, mapping_result, output);
|
||||
}
|
||||
|
||||
void* CustomOpVectorToString::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelVectorToString(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class VectorToStringImpl {
|
|||
};
|
||||
|
||||
struct KernelVectorToString : BaseKernel {
|
||||
KernelVectorToString(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -41,7 +41,7 @@ struct KernelVectorToString : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -78,7 +78,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
|||
return result;
|
||||
}
|
||||
|
||||
KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
|
@ -105,7 +105,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelBasicTokenizer(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -21,14 +21,14 @@ class BasicTokenizer {
|
|||
};
|
||||
|
||||
struct KernelBasicTokenizer : BaseKernel {
|
||||
KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||
};
|
||||
|
||||
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -259,7 +259,7 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
|
|||
}
|
||||
}
|
||||
|
||||
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
||||
|
@ -317,7 +317,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
|||
SetOutput(context, 2, output_dim, attention_mask);
|
||||
}
|
||||
|
||||
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelBertTokenizer(api, info);
|
||||
}
|
||||
|
||||
|
@ -339,7 +339,7 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
}
|
||||
|
||||
KernelHfBertTokenizer::KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||
|
||||
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
|
@ -373,7 +373,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
|||
SetOutput(context, 2, outer_dims, token_type_ids);
|
||||
}
|
||||
|
||||
void* CustomOpHfBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelHfBertTokenizer(api, info);
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class BertTokenizer final {
|
|||
};
|
||||
|
||||
struct KernelBertTokenizer : BaseKernel {
|
||||
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -98,7 +98,7 @@ struct KernelBertTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
@ -107,12 +107,12 @@ struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBe
|
|||
};
|
||||
|
||||
struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||
KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -112,7 +112,7 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
|
|||
return false;
|
||||
}
|
||||
|
||||
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
||||
|
@ -170,7 +170,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpBertTokenizerDecoder::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelBertTokenizerDecoder(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class BertTokenizerDecoder {
|
|||
};
|
||||
|
||||
struct KernelBertTokenizerDecoder : BaseKernel {
|
||||
KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
std::shared_ptr<BertTokenizerDecoder> decoder_;
|
||||
|
@ -43,7 +43,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
if (model_data_.empty()) {
|
||||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
|
@ -78,7 +78,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
||||
}
|
||||
|
||||
void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelBlingFireSentenceBreaker(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ extern "C" int FreeModel(void* ModelPtr);
|
|||
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
||||
|
||||
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||
KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
private:
|
||||
using ModelPtr = std::shared_ptr<void>;
|
||||
|
@ -26,7 +26,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -463,7 +463,7 @@ bool IsEmptyUString(const ustring& str) {
|
|||
|
||||
|
||||
|
||||
KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info)
|
||||
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
|
||||
: BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||
if (vocab.empty()) {
|
||||
|
@ -582,7 +582,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpBpeTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelBpeTokenizer(api, info);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
class VocabData;
|
||||
|
||||
struct KernelBpeTokenizer : BaseKernel {
|
||||
KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -18,7 +18,7 @@ struct KernelBpeTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_tensor.h"
|
||||
#include "base64.h"
|
||||
|
||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
sentencepiece::ModelProto model_proto;
|
||||
std::vector<uint8_t> model_as_bytes;
|
||||
|
@ -100,7 +100,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpSentencepieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelSentencepieceTokenizer(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#include "sentencepiece_processor.h"
|
||||
|
||||
struct KernelSentencepieceTokenizer : BaseKernel {
|
||||
KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "wordpiece_tokenizer.hpp"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
|
||||
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
|
||||
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||
|
@ -162,7 +162,7 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
ptr_row_lengths[i] = row_begins[i];
|
||||
}
|
||||
|
||||
void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelWordpieceTokenizer(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
#include "string_tensor.h"
|
||||
|
||||
struct KernelWordpieceTokenizer : BaseKernel {
|
||||
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info);
|
||||
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -22,7 +22,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
|
|||
}
|
||||
|
||||
static py::object BuildPyObjFromTensor(
|
||||
OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
||||
const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
||||
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
||||
std::vector<npy_intp> npy_dims;
|
||||
for (auto n : shape) {
|
||||
|
@ -210,7 +210,7 @@ typedef struct {
|
|||
std::vector<int64_t> dimensions;
|
||||
} InputInformation;
|
||||
|
||||
PyCustomOpKernel::PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info,
|
||||
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
|
||||
uint64_t id, const std::vector<std::string>& attrs)
|
||||
: api_(api),
|
||||
ort_(api_),
|
||||
|
@ -377,7 +377,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
|
|||
// No need to protect against concurrent access, GIL is doing that.
|
||||
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
|
||||
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
||||
assert(success || it_domain_op->second.length() > 0);
|
||||
assert(success || !it_domain_op->second.empty());
|
||||
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
|
||||
}
|
||||
|
||||
|
|
|
@ -37,11 +37,11 @@ struct PyCustomOpDef {
|
|||
};
|
||||
|
||||
struct PyCustomOpKernel {
|
||||
PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
|
||||
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
||||
const OrtApi& api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
uint64_t obj_id_;
|
||||
std::map<std::string, std::string> attrs_values_;
|
||||
|
@ -60,7 +60,7 @@ struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel
|
|||
op_type_ = op;
|
||||
}
|
||||
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
|
||||
};
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ const char* GetLibraryPath() {
|
|||
}
|
||||
|
||||
struct KernelOne : BaseKernel {
|
||||
KernelOne(OrtApi api) : BaseKernel(api) {
|
||||
KernelOne(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -52,7 +52,7 @@ struct KernelOne : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelOne(api);
|
||||
};
|
||||
const char* GetName() const {
|
||||
|
@ -73,7 +73,7 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
|||
};
|
||||
|
||||
struct KernelTwo : BaseKernel {
|
||||
KernelTwo(OrtApi api) : BaseKernel(api) {
|
||||
KernelTwo(const OrtApi& api) : BaseKernel(api) {
|
||||
}
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
|
@ -98,7 +98,7 @@ struct KernelTwo : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelTwo(api);
|
||||
};
|
||||
const char* GetName() const {
|
||||
|
@ -119,7 +119,7 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
|||
};
|
||||
|
||||
struct KernelThree : BaseKernel {
|
||||
KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
if (!TryToGetAttribute("substr", substr_)) {
|
||||
substr_ = "";
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ struct KernelThree : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelThree(api, info);
|
||||
};
|
||||
const char* GetName() const {
|
||||
|
|
Загрузка…
Ссылка в новой задаче