Issue #288: Prevent copying of OrtApi struct (#290)

This commit is contained in:
Adrian Lizarraga 2022-09-14 09:45:33 -07:00 коммит произвёл GitHub
Родитель 658aa28572
Коммит ae416f6aa6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
60 изменённых файлов: 145 добавлений и 145 удалений

Просмотреть файл

@ -18,8 +18,8 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
const char c_OpDomain[] = "ai.onnx.contrib"; const char c_OpDomain[] = "ai.onnx.contrib";
struct BaseKernel { struct BaseKernel {
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {} BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {} BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
bool HasAttribute(const char* name) const; bool HasAttribute(const char* name) const;
@ -37,7 +37,7 @@ struct BaseKernel {
protected: protected:
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status); OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_ const OrtApi& api_;
Ort::CustomOpApi ort_; Ort::CustomOpApi ort_;
const OrtKernelInfo* info_; const OrtKernelInfo* info_;
}; };

Просмотреть файл

@ -3,7 +3,7 @@
struct KernelGaussianBlur : BaseKernel { struct KernelGaussianBlur : BaseKernel {
KernelGaussianBlur(OrtApi api) : BaseKernel(api) { KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context) { void Compute(OrtKernelContext* context) {
@ -78,7 +78,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
return "GaussianBlur"; return "GaussianBlur";
} }
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelGaussianBlur(api); return new KernelGaussianBlur(api);
} }
}; };

Просмотреть файл

@ -5,7 +5,7 @@
struct KernelImageReader : BaseKernel { struct KernelImageReader : BaseKernel {
KernelImageReader(OrtApi api) : BaseKernel(api) { KernelImageReader(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context) { void Compute(OrtKernelContext* context) {
@ -49,7 +49,7 @@ struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageR
return "ImageReader"; return "ImageReader";
} }
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelImageReader(api); return new KernelImageReader(api);
} }
}; };

Просмотреть файл

@ -8,7 +8,7 @@
struct KernelInverse : BaseKernel { struct KernelInverse : BaseKernel {
KernelInverse(OrtApi api) : BaseKernel(api) { KernelInverse(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context) { void Compute(OrtKernelContext* context) {
@ -34,7 +34,7 @@ struct KernelInverse : BaseKernel {
}; };
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> { struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelInverse(api); return new KernelInverse(api);
} }

Просмотреть файл

@ -6,7 +6,7 @@
#include "ocos.h" #include "ocos.h"
struct KernelNegPos : BaseKernel { struct KernelNegPos : BaseKernel {
KernelNegPos(OrtApi api) : BaseKernel(api) { KernelNegPos(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context){ void Compute(OrtKernelContext* context){
@ -40,7 +40,7 @@ struct KernelNegPos : BaseKernel {
}; };
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> { struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const{ void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelNegPos(api); return new KernelNegPos(api);
} }

Просмотреть файл

@ -3,7 +3,7 @@
#include "segment_extraction.hpp" #include "segment_extraction.hpp"
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) { KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
} }
void KernelSegmentExtraction::Compute(OrtKernelContext* context) { void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
@ -51,7 +51,7 @@ ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*inde
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
void* CustomOpSegmentExtraction::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpSegmentExtraction::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelSegmentExtraction(api); return new KernelSegmentExtraction(api);
}; };

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelSegmentExtraction : BaseKernel { struct KernelSegmentExtraction : BaseKernel {
KernelSegmentExtraction(OrtApi api); KernelSegmentExtraction(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
@ -16,6 +16,6 @@ struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction,
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };

Просмотреть файл

@ -52,7 +52,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
} }
} }
KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) { KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
} }
void KernelSegmentSum::Compute(OrtKernelContext* context) { void KernelSegmentSum::Compute(OrtKernelContext* context) {
@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) co
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}; };
void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpSegmentSum::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelSegmentSum(api); return new KernelSegmentSum(api);
}; };

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelSegmentSum : BaseKernel { struct KernelSegmentSum : BaseKernel {
KernelSegmentSum(OrtApi api); KernelSegmentSum(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
@ -16,6 +16,6 @@ struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentS
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm> #include <algorithm>
KernelMaskedFill::KernelMaskedFill(OrtApi api, const OrtKernelInfo* /*info*/) : BaseKernel(api) { KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
} }
void KernelMaskedFill::Compute(OrtKernelContext* context) { void KernelMaskedFill::Compute(OrtKernelContext* context) {
@ -51,7 +51,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output); FillTensorDataString(api_, ort_, context, result, output);
} }
void* CustomOpMaskedFill::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelMaskedFill(api, info); return new KernelMaskedFill(api, info);
}; };

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map> #include <unordered_map>
struct KernelMaskedFill : BaseKernel { struct KernelMaskedFill : BaseKernel {
KernelMaskedFill(OrtApi api, const OrtKernelInfo* info); KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
std::unordered_map<std::string, std::string> map_; std::unordered_map<std::string, std::string> map_;
}; };
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> { struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -5,7 +5,7 @@
#include "op_equal_impl.hpp" #include "op_equal_impl.hpp"
#include <string> #include <string>
KernelStringEqual::KernelStringEqual(OrtApi api) : BaseKernel(api) { KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringEqual::Compute(OrtKernelContext* context) { void KernelStringEqual::Compute(OrtKernelContext* context) {
@ -24,7 +24,7 @@ ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) c
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
}; };
void* CustomOpStringEqual::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const{ void* CustomOpStringEqual::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const{
return new KernelStringEqual(api); return new KernelStringEqual(api);
}; };

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringEqual : BaseKernel { struct KernelStringEqual : BaseKernel {
KernelStringEqual(OrtApi api); KernelStringEqual(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
@ -16,6 +16,6 @@ struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelString
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_tensor.h" #include "string_tensor.h"
#include "op_ragged_tensor.hpp" #include "op_ragged_tensor.hpp"
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) { KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(const OrtApi& api) : BaseKernel(api) {
} }
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) { void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
@ -53,7 +53,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*i
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
void* CustomOpRaggedTensorToSparse::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpRaggedTensorToSparse::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelRaggedTensorToSparse(api); return new KernelRaggedTensorToSparse(api);
}; };
@ -65,7 +65,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
CommonRaggedTensorToDense::CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
} }
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) { void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
@ -84,7 +84,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
return max_col; return max_col;
} }
KernelRaggedTensorToDense::KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) { KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1; missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
} }
@ -134,7 +134,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
void* CustomOpRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelRaggedTensorToDense(api, info); return new KernelRaggedTensorToDense(api, info);
}; };
@ -146,7 +146,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) { KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
} }
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) { void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
@ -193,7 +193,7 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}; };
void* CustomOpStringRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRaggedTensorToDense(api, info); return new KernelStringRaggedTensorToDense(api, info);
}; };

Просмотреть файл

@ -6,7 +6,7 @@
#include "ocos.h" #include "ocos.h"
struct KernelRaggedTensorToSparse : BaseKernel { struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(OrtApi api); KernelRaggedTensorToSparse(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
@ -15,12 +15,12 @@ struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSp
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };
struct CommonRaggedTensorToDense : BaseKernel { struct CommonRaggedTensorToDense : BaseKernel {
CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info); CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
protected: protected:
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims); void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
@ -28,7 +28,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
}; };
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense { struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info); KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -40,12 +40,12 @@ struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDen
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense { struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info); KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
@ -54,6 +54,6 @@ struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRagge
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const; const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
}; };

Просмотреть файл

@ -8,7 +8,7 @@
#include "re2/re2.h" #include "re2/re2.h"
#include "string_tensor.h" #include "string_tensor.h"
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1; global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
} }
@ -62,7 +62,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output); FillTensorDataString(api_, ort_, context, str_input, output);
} }
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRegexReplace(api, info); return new KernelStringRegexReplace(api, info);
}; };

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringRegexReplace : BaseKernel { struct KernelStringRegexReplace : BaseKernel {
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info); KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
protected: protected:
@ -15,7 +15,7 @@ struct KernelStringRegexReplace : BaseKernel {
}; };
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> { struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
} }
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) { void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
@ -79,7 +79,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t)); memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
} }
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRegexSplitWithOffsets(api, info); return new KernelStringRegexSplitWithOffsets(api, info);
}; };

Просмотреть файл

@ -8,12 +8,12 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md. // See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringRegexSplitWithOffsets : BaseKernel { struct KernelStringRegexSplitWithOffsets : BaseKernel {
KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info); KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> { struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm> #include <algorithm>
KernelStringConcat::KernelStringConcat(OrtApi api) : BaseKernel(api) { KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringConcat::Compute(OrtKernelContext* context) { void KernelStringConcat::Compute(OrtKernelContext* context) {
@ -37,7 +37,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, left_value, output); FillTensorDataString(api_, ort_, context, left_value, output);
} }
void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringConcat::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringConcat(api); return new KernelStringConcat(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringConcat : BaseKernel { struct KernelStringConcat : BaseKernel {
KernelStringConcat(OrtApi api); KernelStringConcat(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> { struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <regex> #include <regex>
#include "string_tensor.h" #include "string_tensor.h"
KernelStringECMARegexReplace::KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = TryToGetAttributeWithDefault("global_replace", true); global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false); ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
@ -70,7 +70,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output); FillTensorDataString(api_, ort_, context, str_input, output);
} }
void* CustomOpStringECMARegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringECMARegexReplace(api, info); return new KernelStringECMARegexReplace(api, info);
}; };

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringECMARegexReplace : BaseKernel { struct KernelStringECMARegexReplace : BaseKernel {
KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info); KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
protected: protected:
@ -16,7 +16,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
}; };
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> { struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -10,7 +10,7 @@
#include "string_tensor.h" #include "string_tensor.h"
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false); ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
} }
@ -84,7 +84,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t)); memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
} }
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringECMARegexSplitWithOffsets(api, info); return new KernelStringECMARegexSplitWithOffsets(api, info);
}; };

Просмотреть файл

@ -9,14 +9,14 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md. // See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringECMARegexSplitWithOffsets : BaseKernel { struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info); KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
bool ignore_case_; bool ignore_case_;
}; };
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> { struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include "string_hash.hpp" #include "string_hash.hpp"
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) { KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringHash::Compute(OrtKernelContext* context) { void KernelStringHash::Compute(OrtKernelContext* context) {
@ -43,7 +43,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
} }
} }
void* CustomOpStringHash::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringHash::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringHash(api); return new KernelStringHash(api);
}; };
@ -72,7 +72,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}; };
KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) { KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringHashFast::Compute(OrtKernelContext* context) { void KernelStringHashFast::Compute(OrtKernelContext* context) {
@ -106,7 +106,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
} }
} }
void* CustomOpStringHashFast::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringHashFast::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringHashFast(api); return new KernelStringHashFast(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringHash : BaseKernel { struct KernelStringHash : BaseKernel {
KernelStringHash(OrtApi api); KernelStringHash(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> { struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
@ -21,12 +21,12 @@ struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHa
}; };
struct KernelStringHashFast : BaseKernel { struct KernelStringHashFast : BaseKernel {
KernelStringHashFast(OrtApi api); KernelStringHashFast(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> { struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_join.hpp" #include "string_join.hpp"
#include "string_tensor.h" #include "string_tensor.h"
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) { KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringJoin::Compute(OrtKernelContext* context) { void KernelStringJoin::Compute(OrtKernelContext* context) {
@ -86,7 +86,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, out, output); FillTensorDataString(api_, ort_, context, out, output);
} }
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringJoin::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringJoin(api); return new KernelStringJoin(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringJoin : BaseKernel { struct KernelStringJoin : BaseKernel {
KernelStringJoin(OrtApi api); KernelStringJoin(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> { struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm> #include <algorithm>
KernelStringLength::KernelStringLength(OrtApi api) : BaseKernel(api) { KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringLength::Compute(OrtKernelContext* context) { void KernelStringLength::Compute(OrtKernelContext* context) {
@ -27,7 +27,7 @@ void KernelStringLength::Compute(OrtKernelContext* context) {
} }
} }
void* CustomOpStringLength::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringLength::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringLength(api); return new KernelStringLength(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringLength : BaseKernel { struct KernelStringLength : BaseKernel {
KernelStringLength(OrtApi api); KernelStringLength(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> { struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
KernelStringLower::KernelStringLower(OrtApi api) : BaseKernel(api) { KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringLower::Compute(OrtKernelContext* context) { void KernelStringLower::Compute(OrtKernelContext* context) {
@ -25,7 +25,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, X, output); FillTensorDataString(api_, ort_, context, X, output);
} }
void* CustomOpStringLower::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringLower::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringLower(api); return new KernelStringLower(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringLower : BaseKernel { struct KernelStringLower : BaseKernel {
KernelStringLower(OrtApi api); KernelStringLower(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> { struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm> #include <algorithm>
KernelStringMapping::KernelStringMapping(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api) { KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map"); std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
auto lines = SplitString(map, "\n", true); auto lines = SplitString(map, "\n", true);
for (const auto& line: lines) { for (const auto& line: lines) {
@ -41,7 +41,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, input_data, output); FillTensorDataString(api_, ort_, context, input_data, output);
} }
void* CustomOpStringMapping::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringMapping(api, info); return new KernelStringMapping(api, info);
}; };

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map> #include <unordered_map>
struct KernelStringMapping : BaseKernel { struct KernelStringMapping : BaseKernel {
KernelStringMapping(OrtApi api, const OrtKernelInfo* info); KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
std::unordered_map<std::string, std::string> map_; std::unordered_map<std::string, std::string> map_;
}; };
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> { struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_split.hpp" #include "string_split.hpp"
#include "string_tensor.h" #include "string_tensor.h"
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) { KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringSplit::Compute(OrtKernelContext* context) { void KernelStringSplit::Compute(OrtKernelContext* context) {
@ -96,7 +96,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, words, out_text); FillTensorDataString(api_, ort_, context, words, out_text);
} }
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringSplit::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringSplit(api); return new KernelStringSplit(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringSplit : BaseKernel { struct KernelStringSplit : BaseKernel {
KernelStringSplit(OrtApi api); KernelStringSplit(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> { struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -100,7 +100,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
} }
} }
KernelStringToVector::KernelStringToVector(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map"); std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector // unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk"); std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
@ -132,7 +132,7 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
} }
} }
void* CustomOpStringToVector::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringToVector(api, info); return new KernelStringToVector(api, info);
}; };

Просмотреть файл

@ -29,7 +29,7 @@ class StringToVectorImpl {
}; };
struct KernelStringToVector : BaseKernel { struct KernelStringToVector : BaseKernel {
KernelStringToVector(OrtApi api, const OrtKernelInfo* info); KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -37,7 +37,7 @@ struct KernelStringToVector : BaseKernel {
}; };
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> { struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) { KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
} }
void KernelStringUpper::Compute(OrtKernelContext* context) { void KernelStringUpper::Compute(OrtKernelContext* context) {
@ -26,7 +26,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, X, output); FillTensorDataString(api_, ort_, context, X, output);
} }
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { void* CustomOpStringUpper::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringUpper(api); return new KernelStringUpper(api);
}; };

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h" #include "string_utils.h"
struct KernelStringUpper : BaseKernel { struct KernelStringUpper : BaseKernel {
KernelStringUpper(OrtApi api); KernelStringUpper(const OrtApi& api);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> { struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -103,7 +103,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
} }
} }
KernelVectorToString::KernelVectorToString(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map"); std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk"); std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
@ -124,7 +124,7 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, mapping_result, output); FillTensorDataString(api_, ort_, context, mapping_result, output);
} }
void* CustomOpVectorToString::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelVectorToString(api, info); return new KernelVectorToString(api, info);
}; };

Просмотреть файл

@ -33,7 +33,7 @@ class VectorToStringImpl {
}; };
struct KernelVectorToString : BaseKernel { struct KernelVectorToString : BaseKernel {
KernelVectorToString(OrtApi api, const OrtKernelInfo* info); KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -41,7 +41,7 @@ struct KernelVectorToString : BaseKernel {
}; };
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> { struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -78,7 +78,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
return result; return result;
} }
KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true); bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true); bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false); bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
@ -105,7 +105,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output); FillTensorDataString(api_, ort_, context, result, output);
} }
void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBasicTokenizer(api, info); return new KernelBasicTokenizer(api, info);
}; };

Просмотреть файл

@ -21,14 +21,14 @@ class BasicTokenizer {
}; };
struct KernelBasicTokenizer : BaseKernel { struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info); KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
std::shared_ptr<BasicTokenizer> tokenizer_; std::shared_ptr<BasicTokenizer> tokenizer_;
}; };
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> { struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -259,7 +259,7 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
} }
} }
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file"); std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true); bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true); bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
@ -317,7 +317,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, output_dim, attention_mask); SetOutput(context, 2, output_dim, attention_mask);
} }
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBertTokenizer(api, info); return new KernelBertTokenizer(api, info);
} }
@ -339,7 +339,7 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
} }
KernelHfBertTokenizer::KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {} KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) { void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs // Setup inputs
@ -373,7 +373,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, outer_dims, token_type_ids); SetOutput(context, 2, outer_dims, token_type_ids);
} }
void* CustomOpHfBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelHfBertTokenizer(api, info); return new KernelHfBertTokenizer(api, info);
} }

Просмотреть файл

@ -90,7 +90,7 @@ class BertTokenizer final {
}; };
struct KernelBertTokenizer : BaseKernel { struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info); KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
protected: protected:
@ -98,7 +98,7 @@ struct KernelBertTokenizer : BaseKernel {
}; };
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> { struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
@ -107,12 +107,12 @@ struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBe
}; };
struct KernelHfBertTokenizer : KernelBertTokenizer { struct KernelHfBertTokenizer : KernelBertTokenizer {
KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info); KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> { struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -112,7 +112,7 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
return false; return false;
} }
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file"); std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]")); std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]")); std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
@ -170,7 +170,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output); FillTensorDataString(api_, ort_, context, result, output);
} }
void* CustomOpBertTokenizerDecoder::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBertTokenizerDecoder(api, info); return new KernelBertTokenizerDecoder(api, info);
}; };

Просмотреть файл

@ -33,7 +33,7 @@ class BertTokenizerDecoder {
}; };
struct KernelBertTokenizerDecoder : BaseKernel { struct KernelBertTokenizerDecoder : BaseKernel {
KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info); KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
std::shared_ptr<BertTokenizerDecoder> decoder_; std::shared_ptr<BertTokenizerDecoder> decoder_;
@ -43,7 +43,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
}; };
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> { struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) { KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model"); model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
if (model_data_.empty()) { if (model_data_.empty()) {
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT); ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
@ -78,7 +78,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size())); Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
} }
void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBlingFireSentenceBreaker(api, info); return new KernelBlingFireSentenceBreaker(api, info);
}; };

Просмотреть файл

@ -16,7 +16,7 @@ extern "C" int FreeModel(void* ModelPtr);
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount); extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
struct KernelBlingFireSentenceBreaker : BaseKernel { struct KernelBlingFireSentenceBreaker : BaseKernel {
KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info); KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
using ModelPtr = std::shared_ptr<void>; using ModelPtr = std::shared_ptr<void>;
@ -26,7 +26,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
}; };
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> { struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -463,7 +463,7 @@ bool IsEmptyUString(const ustring& str) {
KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info) KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
: BaseKernel(api, info) { : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab"); std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
if (vocab.empty()) { if (vocab.empty()) {
@ -582,7 +582,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
} }
} }
void* CustomOpBpeTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBpeTokenizer(api, info); return new KernelBpeTokenizer(api, info);
} }

Просмотреть файл

@ -6,7 +6,7 @@
class VocabData; class VocabData;
struct KernelBpeTokenizer : BaseKernel { struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info); KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -18,7 +18,7 @@ struct KernelBpeTokenizer : BaseKernel {
}; };
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> { struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_tensor.h" #include "string_tensor.h"
#include "base64.h" #include "base64.h"
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model"); std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
sentencepiece::ModelProto model_proto; sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes; std::vector<uint8_t> model_as_bytes;
@ -100,7 +100,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t)); memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
} }
void* CustomOpSentencepieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelSentencepieceTokenizer(api, info); return new KernelSentencepieceTokenizer(api, info);
}; };

Просмотреть файл

@ -8,7 +8,7 @@
#include "sentencepiece_processor.h" #include "sentencepiece_processor.h"
struct KernelSentencepieceTokenizer : BaseKernel { struct KernelSentencepieceTokenizer : BaseKernel {
KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info); KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
}; };
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> { struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "wordpiece_tokenizer.hpp" #include "wordpiece_tokenizer.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md // https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py // https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab"); std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
@ -162,7 +162,7 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
ptr_row_lengths[i] = row_begins[i]; ptr_row_lengths[i] = row_begins[i];
} }
void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelWordpieceTokenizer(api, info); return new KernelWordpieceTokenizer(api, info);
}; };

Просмотреть файл

@ -11,7 +11,7 @@
#include "string_tensor.h" #include "string_tensor.h"
struct KernelWordpieceTokenizer : BaseKernel { struct KernelWordpieceTokenizer : BaseKernel {
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info); KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
@ -22,7 +22,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
}; };
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> { struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
} }
static py::object BuildPyObjFromTensor( static py::object BuildPyObjFromTensor(
OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value, const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
const shape_t& shape, ONNXTensorElementDataType dtype) { const shape_t& shape, ONNXTensorElementDataType dtype) {
std::vector<npy_intp> npy_dims; std::vector<npy_intp> npy_dims;
for (auto n : shape) { for (auto n : shape) {
@ -210,7 +210,7 @@ typedef struct {
std::vector<int64_t> dimensions; std::vector<int64_t> dimensions;
} InputInformation; } InputInformation;
PyCustomOpKernel::PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
uint64_t id, const std::vector<std::string>& attrs) uint64_t id, const std::vector<std::string>& attrs)
: api_(api), : api_(api),
ort_(api_), ort_(api_),
@ -377,7 +377,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
// No need to protect against concurrent access, GIL is doing that. // No need to protect against concurrent access, GIL is doing that.
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>()); auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
const auto [it_domain_op, success] = PyOp_container().insert(val); const auto [it_domain_op, success] = PyOp_container().insert(val);
assert(success || it_domain_op->second.length() > 0); assert(success || !it_domain_op->second.empty());
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op)); it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
} }

Просмотреть файл

@ -37,11 +37,11 @@ struct PyCustomOpDef {
}; };
struct PyCustomOpKernel { struct PyCustomOpKernel {
PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs); PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_ const OrtApi& api_;
Ort::CustomOpApi ort_; Ort::CustomOpApi ort_;
uint64_t obj_id_; uint64_t obj_id_;
std::map<std::string, std::string> attrs_values_; std::map<std::string, std::string> attrs_values_;
@ -60,7 +60,7 @@ struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel
op_type_ = op; op_type_ = op;
} }
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs); return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
}; };

Просмотреть файл

@ -24,7 +24,7 @@ const char* GetLibraryPath() {
} }
struct KernelOne : BaseKernel { struct KernelOne : BaseKernel {
KernelOne(OrtApi api) : BaseKernel(api) { KernelOne(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context) { void Compute(OrtKernelContext* context) {
@ -52,7 +52,7 @@ struct KernelOne : BaseKernel {
}; };
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> { struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelOne(api); return new KernelOne(api);
}; };
const char* GetName() const { const char* GetName() const {
@ -73,7 +73,7 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
}; };
struct KernelTwo : BaseKernel { struct KernelTwo : BaseKernel {
KernelTwo(OrtApi api) : BaseKernel(api) { KernelTwo(const OrtApi& api) : BaseKernel(api) {
} }
void Compute(OrtKernelContext* context) { void Compute(OrtKernelContext* context) {
// Setup inputs // Setup inputs
@ -98,7 +98,7 @@ struct KernelTwo : BaseKernel {
}; };
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> { struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelTwo(api); return new KernelTwo(api);
}; };
const char* GetName() const { const char* GetName() const {
@ -119,7 +119,7 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
}; };
struct KernelThree : BaseKernel { struct KernelThree : BaseKernel {
KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
if (!TryToGetAttribute("substr", substr_)) { if (!TryToGetAttribute("substr", substr_)) {
substr_ = ""; substr_ = "";
} }
@ -145,7 +145,7 @@ struct KernelThree : BaseKernel {
}; };
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> { struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelThree(api, info); return new KernelThree(api, info);
}; };
const char* GetName() const { const char* GetName() const {