Родитель
658aa28572
Коммит
ae416f6aa6
|
@ -18,8 +18,8 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
|
||||||
const char c_OpDomain[] = "ai.onnx.contrib";
|
const char c_OpDomain[] = "ai.onnx.contrib";
|
||||||
|
|
||||||
struct BaseKernel {
|
struct BaseKernel {
|
||||||
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
|
BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||||
BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
|
BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
|
||||||
|
|
||||||
bool HasAttribute(const char* name) const;
|
bool HasAttribute(const char* name) const;
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ struct BaseKernel {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
||||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
const OrtApi& api_;
|
||||||
Ort::CustomOpApi ort_;
|
Ort::CustomOpApi ort_;
|
||||||
const OrtKernelInfo* info_;
|
const OrtKernelInfo* info_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
|
|
||||||
struct KernelGaussianBlur : BaseKernel {
|
struct KernelGaussianBlur : BaseKernel {
|
||||||
KernelGaussianBlur(OrtApi api) : BaseKernel(api) {
|
KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context) {
|
void Compute(OrtKernelContext* context) {
|
||||||
|
@ -78,7 +78,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
|
||||||
return "GaussianBlur";
|
return "GaussianBlur";
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelGaussianBlur(api);
|
return new KernelGaussianBlur(api);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
|
|
||||||
struct KernelImageReader : BaseKernel {
|
struct KernelImageReader : BaseKernel {
|
||||||
KernelImageReader(OrtApi api) : BaseKernel(api) {
|
KernelImageReader(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context) {
|
void Compute(OrtKernelContext* context) {
|
||||||
|
@ -49,7 +49,7 @@ struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageR
|
||||||
return "ImageReader";
|
return "ImageReader";
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelImageReader(api);
|
return new KernelImageReader(api);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
|
|
||||||
struct KernelInverse : BaseKernel {
|
struct KernelInverse : BaseKernel {
|
||||||
KernelInverse(OrtApi api) : BaseKernel(api) {
|
KernelInverse(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context) {
|
void Compute(OrtKernelContext* context) {
|
||||||
|
@ -34,7 +34,7 @@ struct KernelInverse : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelInverse(api);
|
return new KernelInverse(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
struct KernelNegPos : BaseKernel {
|
struct KernelNegPos : BaseKernel {
|
||||||
KernelNegPos(OrtApi api) : BaseKernel(api) {
|
KernelNegPos(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context){
|
void Compute(OrtKernelContext* context){
|
||||||
|
@ -40,7 +40,7 @@ struct KernelNegPos : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const{
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelNegPos(api);
|
return new KernelNegPos(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
#include "segment_extraction.hpp"
|
#include "segment_extraction.hpp"
|
||||||
|
|
||||||
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
|
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
||||||
|
@ -51,7 +51,7 @@ ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*inde
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpSegmentExtraction::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpSegmentExtraction::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelSegmentExtraction(api);
|
return new KernelSegmentExtraction(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelSegmentExtraction : BaseKernel {
|
struct KernelSegmentExtraction : BaseKernel {
|
||||||
KernelSegmentExtraction(OrtApi api);
|
KernelSegmentExtraction(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -16,6 +16,6 @@ struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction,
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -52,7 +52,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) {
|
KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelSegmentSum::Compute(OrtKernelContext* context) {
|
void KernelSegmentSum::Compute(OrtKernelContext* context) {
|
||||||
|
@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) co
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpSegmentSum::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelSegmentSum(api);
|
return new KernelSegmentSum(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelSegmentSum : BaseKernel {
|
struct KernelSegmentSum : BaseKernel {
|
||||||
KernelSegmentSum(OrtApi api);
|
KernelSegmentSum(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -16,6 +16,6 @@ struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentS
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
KernelMaskedFill::KernelMaskedFill(OrtApi api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
|
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||||
|
@ -51,7 +51,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, result, output);
|
FillTensorDataString(api_, ort_, context, result, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpMaskedFill::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelMaskedFill(api, info);
|
return new KernelMaskedFill(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,14 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
struct KernelMaskedFill : BaseKernel {
|
struct KernelMaskedFill : BaseKernel {
|
||||||
KernelMaskedFill(OrtApi api, const OrtKernelInfo* info);
|
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, std::string> map_;
|
std::unordered_map<std::string, std::string> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
#include "op_equal_impl.hpp"
|
#include "op_equal_impl.hpp"
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
KernelStringEqual::KernelStringEqual(OrtApi api) : BaseKernel(api) {
|
KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringEqual::Compute(OrtKernelContext* context) {
|
void KernelStringEqual::Compute(OrtKernelContext* context) {
|
||||||
|
@ -24,7 +24,7 @@ ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) c
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpStringEqual::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const{
|
void* CustomOpStringEqual::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const{
|
||||||
return new KernelStringEqual(api);
|
return new KernelStringEqual(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringEqual : BaseKernel {
|
struct KernelStringEqual : BaseKernel {
|
||||||
KernelStringEqual(OrtApi api);
|
KernelStringEqual(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -16,6 +16,6 @@ struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelString
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
#include "op_ragged_tensor.hpp"
|
#include "op_ragged_tensor.hpp"
|
||||||
|
|
||||||
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
|
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
||||||
|
@ -53,7 +53,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*i
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpRaggedTensorToSparse::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpRaggedTensorToSparse::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelRaggedTensorToSparse(api);
|
return new KernelRaggedTensorToSparse(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||||
|
@ -84,7 +84,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
|
||||||
return max_col;
|
return max_col;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||||
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
|
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelRaggedTensorToDense(api, info);
|
return new KernelRaggedTensorToDense(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||||
|
@ -193,7 +193,7 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* CustomOpStringRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringRaggedTensorToDense(api, info);
|
return new KernelStringRaggedTensorToDense(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||||
KernelRaggedTensorToSparse(OrtApi api);
|
KernelRaggedTensorToSparse(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -15,12 +15,12 @@ struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSp
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CommonRaggedTensorToDense : BaseKernel {
|
struct CommonRaggedTensorToDense : BaseKernel {
|
||||||
CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
||||||
|
@ -28,7 +28,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||||
KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -40,12 +40,12 @@ struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDen
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||||
KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
|
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -54,6 +54,6 @@ struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRagge
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
#include "re2/re2.h"
|
#include "re2/re2.h"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
|
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, str_input, output);
|
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringRegexReplace(api, info);
|
return new KernelStringRegexReplace(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringRegexReplace : BaseKernel {
|
struct KernelStringRegexReplace : BaseKernel {
|
||||||
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info);
|
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -15,7 +15,7 @@ struct KernelStringRegexReplace : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||||
|
@ -79,7 +79,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringRegexSplitWithOffsets(api, info);
|
return new KernelStringRegexSplitWithOffsets(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,12 @@
|
||||||
|
|
||||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||||
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
||||||
KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
|
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
KernelStringConcat::KernelStringConcat(OrtApi api) : BaseKernel(api) {
|
KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringConcat::Compute(OrtKernelContext* context) {
|
void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||||
|
@ -37,7 +37,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, left_value, output);
|
FillTensorDataString(api_, ort_, context, left_value, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringConcat::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringConcat(api);
|
return new KernelStringConcat(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringConcat : BaseKernel {
|
struct KernelStringConcat : BaseKernel {
|
||||||
KernelStringConcat(OrtApi api);
|
KernelStringConcat(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
KernelStringECMARegexReplace::KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
|
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
|
||||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, str_input, output);
|
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringECMARegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringECMARegexReplace(api, info);
|
return new KernelStringECMARegexReplace(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringECMARegexReplace : BaseKernel {
|
struct KernelStringECMARegexReplace : BaseKernel {
|
||||||
KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info);
|
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -16,7 +16,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
|
|
||||||
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringECMARegexSplitWithOffsets(api, info);
|
return new KernelStringECMARegexSplitWithOffsets(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -9,14 +9,14 @@
|
||||||
|
|
||||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||||
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
||||||
KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
|
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
bool ignore_case_;
|
bool ignore_case_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include "string_hash.hpp"
|
#include "string_hash.hpp"
|
||||||
|
|
||||||
|
|
||||||
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
|
KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringHash::Compute(OrtKernelContext* context) {
|
void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||||
|
@ -43,7 +43,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringHash::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringHash::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringHash(api);
|
return new KernelStringHash(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
};
|
};
|
||||||
|
|
||||||
KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) {
|
KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||||
|
@ -106,7 +106,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringHashFast::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringHashFast::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringHashFast(api);
|
return new KernelStringHashFast(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringHash : BaseKernel {
|
struct KernelStringHash : BaseKernel {
|
||||||
KernelStringHash(OrtApi api);
|
KernelStringHash(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
|
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
@ -21,12 +21,12 @@ struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHa
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelStringHashFast : BaseKernel {
|
struct KernelStringHashFast : BaseKernel {
|
||||||
KernelStringHashFast(OrtApi api);
|
KernelStringHashFast(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
|
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#include "string_join.hpp"
|
#include "string_join.hpp"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
|
KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringJoin::Compute(OrtKernelContext* context) {
|
void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||||
|
@ -86,7 +86,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, out, output);
|
FillTensorDataString(api_, ort_, context, out, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringJoin::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringJoin(api);
|
return new KernelStringJoin(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringJoin : BaseKernel {
|
struct KernelStringJoin : BaseKernel {
|
||||||
KernelStringJoin(OrtApi api);
|
KernelStringJoin(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
|
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
KernelStringLength::KernelStringLength(OrtApi api) : BaseKernel(api) {
|
KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringLength::Compute(OrtKernelContext* context) {
|
void KernelStringLength::Compute(OrtKernelContext* context) {
|
||||||
|
@ -27,7 +27,7 @@ void KernelStringLength::Compute(OrtKernelContext* context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringLength::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringLength::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringLength(api);
|
return new KernelStringLength(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringLength : BaseKernel {
|
struct KernelStringLength : BaseKernel {
|
||||||
KernelStringLength(OrtApi api);
|
KernelStringLength(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
|
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
KernelStringLower::KernelStringLower(OrtApi api) : BaseKernel(api) {
|
KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringLower::Compute(OrtKernelContext* context) {
|
void KernelStringLower::Compute(OrtKernelContext* context) {
|
||||||
|
@ -25,7 +25,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, X, output);
|
FillTensorDataString(api_, ort_, context, X, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringLower::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringLower::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringLower(api);
|
return new KernelStringLower(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringLower : BaseKernel {
|
struct KernelStringLower : BaseKernel {
|
||||||
KernelStringLower(OrtApi api);
|
KernelStringLower(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
|
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
KernelStringMapping::KernelStringMapping(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api) {
|
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
|
||||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||||
auto lines = SplitString(map, "\n", true);
|
auto lines = SplitString(map, "\n", true);
|
||||||
for (const auto& line: lines) {
|
for (const auto& line: lines) {
|
||||||
|
@ -41,7 +41,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, input_data, output);
|
FillTensorDataString(api_, ort_, context, input_data, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringMapping::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringMapping(api, info);
|
return new KernelStringMapping(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,14 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
struct KernelStringMapping : BaseKernel {
|
struct KernelStringMapping : BaseKernel {
|
||||||
KernelStringMapping(OrtApi api, const OrtKernelInfo* info);
|
KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, std::string> map_;
|
std::unordered_map<std::string, std::string> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#include "string_split.hpp"
|
#include "string_split.hpp"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||||
|
@ -96,7 +96,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, words, out_text);
|
FillTensorDataString(api_, ort_, context, words, out_text);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringSplit::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringSplit(api);
|
return new KernelStringSplit(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringSplit : BaseKernel {
|
struct KernelStringSplit : BaseKernel {
|
||||||
KernelStringSplit(OrtApi api);
|
KernelStringSplit(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
|
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -100,7 +100,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelStringToVector::KernelStringToVector(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||||
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
|
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
|
||||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||||
|
@ -132,7 +132,7 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringToVector::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelStringToVector(api, info);
|
return new KernelStringToVector(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ class StringToVectorImpl {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelStringToVector : BaseKernel {
|
struct KernelStringToVector : BaseKernel {
|
||||||
KernelStringToVector(OrtApi api, const OrtKernelInfo* info);
|
KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -37,7 +37,7 @@ struct KernelStringToVector : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) {
|
KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelStringUpper::Compute(OrtKernelContext* context) {
|
void KernelStringUpper::Compute(OrtKernelContext* context) {
|
||||||
|
@ -26,7 +26,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, X, output);
|
FillTensorDataString(api_, ort_, context, X, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
void* CustomOpStringUpper::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
|
||||||
return new KernelStringUpper(api);
|
return new KernelStringUpper(api);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
|
|
||||||
struct KernelStringUpper : BaseKernel {
|
struct KernelStringUpper : BaseKernel {
|
||||||
KernelStringUpper(OrtApi api);
|
KernelStringUpper(const OrtApi& api);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
|
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -103,7 +103,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelVectorToString::KernelVectorToString(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, mapping_result, output);
|
FillTensorDataString(api_, ort_, context, mapping_result, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpVectorToString::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelVectorToString(api, info);
|
return new KernelVectorToString(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ class VectorToStringImpl {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelVectorToString : BaseKernel {
|
struct KernelVectorToString : BaseKernel {
|
||||||
KernelVectorToString(OrtApi api, const OrtKernelInfo* info);
|
KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -41,7 +41,7 @@ struct KernelVectorToString : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -78,7 +78,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||||
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||||
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||||
|
@ -105,7 +105,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, result, output);
|
FillTensorDataString(api_, ort_, context, result, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelBasicTokenizer(api, info);
|
return new KernelBasicTokenizer(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -21,14 +21,14 @@ class BasicTokenizer {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelBasicTokenizer : BaseKernel {
|
struct KernelBasicTokenizer : BaseKernel {
|
||||||
KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<BasicTokenizer> tokenizer_;
|
std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -259,7 +259,7 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||||
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
||||||
|
@ -317,7 +317,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
||||||
SetOutput(context, 2, output_dim, attention_mask);
|
SetOutput(context, 2, output_dim, attention_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelBertTokenizer(api, info);
|
return new KernelBertTokenizer(api, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -339,7 +339,7 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelHfBertTokenizer::KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||||
|
|
||||||
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||||
// Setup inputs
|
// Setup inputs
|
||||||
|
@ -373,7 +373,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||||
SetOutput(context, 2, outer_dims, token_type_ids);
|
SetOutput(context, 2, outer_dims, token_type_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpHfBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelHfBertTokenizer(api, info);
|
return new KernelHfBertTokenizer(api, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ class BertTokenizer final {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelBertTokenizer : BaseKernel {
|
struct KernelBertTokenizer : BaseKernel {
|
||||||
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -98,7 +98,7 @@ struct KernelBertTokenizer : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
@ -107,12 +107,12 @@ struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBe
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelHfBertTokenizer : KernelBertTokenizer {
|
struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||||
KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -112,7 +112,7 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||||
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||||
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
||||||
|
@ -170,7 +170,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
||||||
FillTensorDataString(api_, ort_, context, result, output);
|
FillTensorDataString(api_, ort_, context, result, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBertTokenizerDecoder::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelBertTokenizerDecoder(api, info);
|
return new KernelBertTokenizerDecoder(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ class BertTokenizerDecoder {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelBertTokenizerDecoder : BaseKernel {
|
struct KernelBertTokenizerDecoder : BaseKernel {
|
||||||
KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info);
|
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<BertTokenizerDecoder> decoder_;
|
std::shared_ptr<BertTokenizerDecoder> decoder_;
|
||||||
|
@ -43,7 +43,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||||
if (model_data_.empty()) {
|
if (model_data_.empty()) {
|
||||||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||||
|
@ -78,7 +78,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelBlingFireSentenceBreaker(api, info);
|
return new KernelBlingFireSentenceBreaker(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ extern "C" int FreeModel(void* ModelPtr);
|
||||||
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
||||||
|
|
||||||
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||||
KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info);
|
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
private:
|
private:
|
||||||
using ModelPtr = std::shared_ptr<void>;
|
using ModelPtr = std::shared_ptr<void>;
|
||||||
|
@ -26,7 +26,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -463,7 +463,7 @@ bool IsEmptyUString(const ustring& str) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info)
|
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
|
||||||
: BaseKernel(api, info) {
|
: BaseKernel(api, info) {
|
||||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||||
if (vocab.empty()) {
|
if (vocab.empty()) {
|
||||||
|
@ -582,7 +582,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBpeTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelBpeTokenizer(api, info);
|
return new KernelBpeTokenizer(api, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
class VocabData;
|
class VocabData;
|
||||||
|
|
||||||
struct KernelBpeTokenizer : BaseKernel {
|
struct KernelBpeTokenizer : BaseKernel {
|
||||||
KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -18,7 +18,7 @@ struct KernelBpeTokenizer : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
#include "base64.h"
|
#include "base64.h"
|
||||||
|
|
||||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||||
sentencepiece::ModelProto model_proto;
|
sentencepiece::ModelProto model_proto;
|
||||||
std::vector<uint8_t> model_as_bytes;
|
std::vector<uint8_t> model_as_bytes;
|
||||||
|
@ -100,7 +100,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
||||||
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
|
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpSentencepieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelSentencepieceTokenizer(api, info);
|
return new KernelSentencepieceTokenizer(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
#include "sentencepiece_processor.h"
|
#include "sentencepiece_processor.h"
|
||||||
|
|
||||||
struct KernelSentencepieceTokenizer : BaseKernel {
|
struct KernelSentencepieceTokenizer : BaseKernel {
|
||||||
KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#include "wordpiece_tokenizer.hpp"
|
#include "wordpiece_tokenizer.hpp"
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
|
|
||||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
|
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
|
||||||
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
|
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
|
||||||
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||||
|
@ -162,7 +162,7 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
|
||||||
ptr_row_lengths[i] = row_begins[i];
|
ptr_row_lengths[i] = row_begins[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelWordpieceTokenizer(api, info);
|
return new KernelWordpieceTokenizer(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
struct KernelWordpieceTokenizer : BaseKernel {
|
struct KernelWordpieceTokenizer : BaseKernel {
|
||||||
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info);
|
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -22,7 +22,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
|
||||||
}
|
}
|
||||||
|
|
||||||
static py::object BuildPyObjFromTensor(
|
static py::object BuildPyObjFromTensor(
|
||||||
OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
||||||
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
||||||
std::vector<npy_intp> npy_dims;
|
std::vector<npy_intp> npy_dims;
|
||||||
for (auto n : shape) {
|
for (auto n : shape) {
|
||||||
|
@ -210,7 +210,7 @@ typedef struct {
|
||||||
std::vector<int64_t> dimensions;
|
std::vector<int64_t> dimensions;
|
||||||
} InputInformation;
|
} InputInformation;
|
||||||
|
|
||||||
PyCustomOpKernel::PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info,
|
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
|
||||||
uint64_t id, const std::vector<std::string>& attrs)
|
uint64_t id, const std::vector<std::string>& attrs)
|
||||||
: api_(api),
|
: api_(api),
|
||||||
ort_(api_),
|
ort_(api_),
|
||||||
|
@ -377,7 +377,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
|
||||||
// No need to protect against concurrent access, GIL is doing that.
|
// No need to protect against concurrent access, GIL is doing that.
|
||||||
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
|
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
|
||||||
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
||||||
assert(success || it_domain_op->second.length() > 0);
|
assert(success || !it_domain_op->second.empty());
|
||||||
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
|
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,11 +37,11 @@ struct PyCustomOpDef {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PyCustomOpKernel {
|
struct PyCustomOpKernel {
|
||||||
PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
|
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
const OrtApi& api_;
|
||||||
Ort::CustomOpApi ort_;
|
Ort::CustomOpApi ort_;
|
||||||
uint64_t obj_id_;
|
uint64_t obj_id_;
|
||||||
std::map<std::string, std::string> attrs_values_;
|
std::map<std::string, std::string> attrs_values_;
|
||||||
|
@ -60,7 +60,7 @@ struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel
|
||||||
op_type_ = op;
|
op_type_ = op;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
|
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ const char* GetLibraryPath() {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct KernelOne : BaseKernel {
|
struct KernelOne : BaseKernel {
|
||||||
KernelOne(OrtApi api) : BaseKernel(api) {
|
KernelOne(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context) {
|
void Compute(OrtKernelContext* context) {
|
||||||
|
@ -52,7 +52,7 @@ struct KernelOne : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelOne(api);
|
return new KernelOne(api);
|
||||||
};
|
};
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
|
@ -73,7 +73,7 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelTwo : BaseKernel {
|
struct KernelTwo : BaseKernel {
|
||||||
KernelTwo(OrtApi api) : BaseKernel(api) {
|
KernelTwo(const OrtApi& api) : BaseKernel(api) {
|
||||||
}
|
}
|
||||||
void Compute(OrtKernelContext* context) {
|
void Compute(OrtKernelContext* context) {
|
||||||
// Setup inputs
|
// Setup inputs
|
||||||
|
@ -98,7 +98,7 @@ struct KernelTwo : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelTwo(api);
|
return new KernelTwo(api);
|
||||||
};
|
};
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
|
@ -119,7 +119,7 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KernelThree : BaseKernel {
|
struct KernelThree : BaseKernel {
|
||||||
KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||||
if (!TryToGetAttribute("substr", substr_)) {
|
if (!TryToGetAttribute("substr", substr_)) {
|
||||||
substr_ = "";
|
substr_ = "";
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ struct KernelThree : BaseKernel {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
|
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelThree(api, info);
|
return new KernelThree(api, info);
|
||||||
};
|
};
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче