Issue #288: Prevent copying of OrtApi struct (#290)

This commit is contained in:
Adrian Lizarraga 2022-09-14 09:45:33 -07:00 коммит произвёл GitHub
Родитель 658aa28572
Коммит ae416f6aa6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
60 изменённых файлов: 145 добавлений и 145 удалений

Просмотреть файл

@ -18,8 +18,8 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
const char c_OpDomain[] = "ai.onnx.contrib";
struct BaseKernel {
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
bool HasAttribute(const char* name) const;
@ -37,7 +37,7 @@ struct BaseKernel {
protected:
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
const OrtApi& api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo* info_;
};

Просмотреть файл

@ -3,7 +3,7 @@
struct KernelGaussianBlur : BaseKernel {
KernelGaussianBlur(OrtApi api) : BaseKernel(api) {
KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
@ -78,7 +78,7 @@ struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaus
return "GaussianBlur";
}
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelGaussianBlur(api);
}
};

Просмотреть файл

@ -5,7 +5,7 @@
struct KernelImageReader : BaseKernel {
KernelImageReader(OrtApi api) : BaseKernel(api) {
KernelImageReader(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
@ -49,7 +49,7 @@ struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageR
return "ImageReader";
}
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelImageReader(api);
}
};

Просмотреть файл

@ -8,7 +8,7 @@
struct KernelInverse : BaseKernel {
KernelInverse(OrtApi api) : BaseKernel(api) {
KernelInverse(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
@ -34,7 +34,7 @@ struct KernelInverse : BaseKernel {
};
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelInverse(api);
}

Просмотреть файл

@ -6,7 +6,7 @@
#include "ocos.h"
struct KernelNegPos : BaseKernel {
KernelNegPos(OrtApi api) : BaseKernel(api) {
KernelNegPos(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context){
@ -40,7 +40,7 @@ struct KernelNegPos : BaseKernel {
};
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const{
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelNegPos(api);
}

Просмотреть файл

@ -3,7 +3,7 @@
#include "segment_extraction.hpp"
KernelSegmentExtraction::KernelSegmentExtraction(OrtApi api) : BaseKernel(api) {
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
}
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
@ -51,7 +51,7 @@ ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*inde
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
void* CustomOpSegmentExtraction::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpSegmentExtraction::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelSegmentExtraction(api);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelSegmentExtraction : BaseKernel {
KernelSegmentExtraction(OrtApi api);
KernelSegmentExtraction(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
@ -16,6 +16,6 @@ struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction,
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};

Просмотреть файл

@ -52,7 +52,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
}
}
KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) {
KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
}
void KernelSegmentSum::Compute(OrtKernelContext* context) {
@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) co
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpSegmentSum::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelSegmentSum(api);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelSegmentSum : BaseKernel {
KernelSegmentSum(OrtApi api);
KernelSegmentSum(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
@ -16,6 +16,6 @@ struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentS
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
KernelMaskedFill::KernelMaskedFill(OrtApi api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
}
void KernelMaskedFill::Compute(OrtKernelContext* context) {
@ -51,7 +51,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpMaskedFill::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelMaskedFill(api, info);
};

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map>
struct KernelMaskedFill : BaseKernel {
KernelMaskedFill(OrtApi api, const OrtKernelInfo* info);
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -5,7 +5,7 @@
#include "op_equal_impl.hpp"
#include <string>
KernelStringEqual::KernelStringEqual(OrtApi api) : BaseKernel(api) {
KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringEqual::Compute(OrtKernelContext* context) {
@ -24,7 +24,7 @@ ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) c
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
};
void* CustomOpStringEqual::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const{
void* CustomOpStringEqual::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const{
return new KernelStringEqual(api);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringEqual : BaseKernel {
KernelStringEqual(OrtApi api);
KernelStringEqual(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
@ -16,6 +16,6 @@ struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelString
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_tensor.h"
#include "op_ragged_tensor.hpp"
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(const OrtApi& api) : BaseKernel(api) {
}
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
@ -53,7 +53,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*i
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
void* CustomOpRaggedTensorToSparse::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpRaggedTensorToSparse::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelRaggedTensorToSparse(api);
};
@ -65,7 +65,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
CommonRaggedTensorToDense::CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
}
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
@ -84,7 +84,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
return max_col;
}
KernelRaggedTensorToDense::KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
}
@ -134,7 +134,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
void* CustomOpRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelRaggedTensorToDense(api, info);
};
@ -146,7 +146,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
}
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
@ -193,7 +193,7 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
void* CustomOpStringRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRaggedTensorToDense(api, info);
};

Просмотреть файл

@ -6,7 +6,7 @@
#include "ocos.h"
struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(OrtApi api);
KernelRaggedTensorToSparse(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
@ -15,12 +15,12 @@ struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSp
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
struct CommonRaggedTensorToDense : BaseKernel {
CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
protected:
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
@ -28,7 +28,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
};
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -40,12 +40,12 @@ struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDen
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info);
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
};
@ -54,6 +54,6 @@ struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRagge
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};

Просмотреть файл

@ -8,7 +8,7 @@
#include "re2/re2.h"
#include "string_tensor.h"
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
}
@ -62,7 +62,7 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output);
}
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRegexReplace(api, info);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringRegexReplace : BaseKernel {
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info);
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
protected:
@ -15,7 +15,7 @@ struct KernelStringRegexReplace : BaseKernel {
};
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <vector>
#include <cmath>
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
}
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
@ -79,7 +79,7 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringRegexSplitWithOffsets(api, info);
};

Просмотреть файл

@ -8,12 +8,12 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringRegexSplitWithOffsets : BaseKernel {
KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
KernelStringConcat::KernelStringConcat(OrtApi api) : BaseKernel(api) {
KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringConcat::Compute(OrtKernelContext* context) {
@ -37,7 +37,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, left_value, output);
}
void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringConcat::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringConcat(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringConcat : BaseKernel {
KernelStringConcat(OrtApi api);
KernelStringConcat(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <regex>
#include "string_tensor.h"
KernelStringECMARegexReplace::KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
@ -70,7 +70,7 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, str_input, output);
}
void* CustomOpStringECMARegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringECMARegexReplace(api, info);
};

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_utils.h"
struct KernelStringECMARegexReplace : BaseKernel {
KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info);
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
protected:
@ -16,7 +16,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
};
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -10,7 +10,7 @@
#include "string_tensor.h"
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}
@ -84,7 +84,7 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringECMARegexSplitWithOffsets(api, info);
};

Просмотреть файл

@ -9,14 +9,14 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
KernelStringECMARegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info);
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
bool ignore_case_;
};
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include "string_hash.hpp"
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringHash::Compute(OrtKernelContext* context) {
@ -43,7 +43,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
}
}
void* CustomOpStringHash::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringHash::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringHash(api);
};
@ -72,7 +72,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) {
KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringHashFast::Compute(OrtKernelContext* context) {
@ -106,7 +106,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
}
}
void* CustomOpStringHashFast::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringHashFast::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringHashFast(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringHash : BaseKernel {
KernelStringHash(OrtApi api);
KernelStringHash(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
@ -21,12 +21,12 @@ struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHa
};
struct KernelStringHashFast : BaseKernel {
KernelStringHashFast(OrtApi api);
KernelStringHashFast(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_join.hpp"
#include "string_tensor.h"
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringJoin::Compute(OrtKernelContext* context) {
@ -86,7 +86,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, out, output);
}
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringJoin::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringJoin(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringJoin : BaseKernel {
KernelStringJoin(OrtApi api);
KernelStringJoin(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
KernelStringLength::KernelStringLength(OrtApi api) : BaseKernel(api) {
KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringLength::Compute(OrtKernelContext* context) {
@ -27,7 +27,7 @@ void KernelStringLength::Compute(OrtKernelContext* context) {
}
}
void* CustomOpStringLength::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringLength::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringLength(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringLength : BaseKernel {
KernelStringLength(OrtApi api);
KernelStringLength(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath>
#include <algorithm>
KernelStringLower::KernelStringLower(OrtApi api) : BaseKernel(api) {
KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringLower::Compute(OrtKernelContext* context) {
@ -25,7 +25,7 @@ void KernelStringLower::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, X, output);
}
void* CustomOpStringLower::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringLower::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringLower(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringLower : BaseKernel {
KernelStringLower(OrtApi api);
KernelStringLower(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
KernelStringMapping::KernelStringMapping(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api) {
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
auto lines = SplitString(map, "\n", true);
for (const auto& line: lines) {
@ -41,7 +41,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, input_data, output);
}
void* CustomOpStringMapping::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringMapping(api, info);
};

Просмотреть файл

@ -8,14 +8,14 @@
#include <unordered_map>
struct KernelStringMapping : BaseKernel {
KernelStringMapping(OrtApi api, const OrtKernelInfo* info);
KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "string_split.hpp"
#include "string_tensor.h"
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringSplit::Compute(OrtKernelContext* context) {
@ -96,7 +96,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, words, out_text);
}
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringSplit::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringSplit(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringSplit : BaseKernel {
KernelStringSplit(OrtApi api);
KernelStringSplit(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -100,7 +100,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
}
}
KernelStringToVector::KernelStringToVector(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
@ -132,7 +132,7 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
}
}
void* CustomOpStringToVector::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelStringToVector(api, info);
};

Просмотреть файл

@ -29,7 +29,7 @@ class StringToVectorImpl {
};
struct KernelStringToVector : BaseKernel {
KernelStringToVector(OrtApi api, const OrtKernelInfo* info);
KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -37,7 +37,7 @@ struct KernelStringToVector : BaseKernel {
};
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include <cmath>
#include <algorithm>
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) {
KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
}
void KernelStringUpper::Compute(OrtKernelContext* context) {
@ -26,7 +26,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, X, output);
}
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
void* CustomOpStringUpper::CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const {
return new KernelStringUpper(api);
};

Просмотреть файл

@ -7,12 +7,12 @@
#include "string_utils.h"
struct KernelStringUpper : BaseKernel {
KernelStringUpper(OrtApi api);
KernelStringUpper(const OrtApi& api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -103,7 +103,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
}
}
KernelVectorToString::KernelVectorToString(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
@ -124,7 +124,7 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, mapping_result, output);
}
void* CustomOpVectorToString::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelVectorToString(api, info);
};

Просмотреть файл

@ -33,7 +33,7 @@ class VectorToStringImpl {
};
struct KernelVectorToString : BaseKernel {
KernelVectorToString(OrtApi api, const OrtKernelInfo* info);
KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -41,7 +41,7 @@ struct KernelVectorToString : BaseKernel {
};
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -78,7 +78,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
return result;
}
KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
@ -105,7 +105,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBasicTokenizer(api, info);
};

Просмотреть файл

@ -21,14 +21,14 @@ class BasicTokenizer {
};
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BasicTokenizer> tokenizer_;
};
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -259,7 +259,7 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
}
}
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
@ -317,7 +317,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, output_dim, attention_mask);
}
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBertTokenizer(api, info);
}
@ -339,7 +339,7 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
KernelHfBertTokenizer::KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs
@ -373,7 +373,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
SetOutput(context, 2, outer_dims, token_type_ids);
}
void* CustomOpHfBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelHfBertTokenizer(api, info);
}

Просмотреть файл

@ -90,7 +90,7 @@ class BertTokenizer final {
};
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
protected:
@ -98,7 +98,7 @@ struct KernelBertTokenizer : BaseKernel {
};
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
@ -107,12 +107,12 @@ struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBe
};
struct KernelHfBertTokenizer : KernelBertTokenizer {
KernelHfBertTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
};
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -112,7 +112,7 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
return false;
}
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
@ -170,7 +170,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpBertTokenizerDecoder::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBertTokenizerDecoder(api, info);
};

Просмотреть файл

@ -33,7 +33,7 @@ class BertTokenizerDecoder {
};
struct KernelBertTokenizerDecoder : BaseKernel {
KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info);
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizerDecoder> decoder_;
@ -43,7 +43,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
};
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -9,7 +9,7 @@
#include <algorithm>
#include <memory>
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
if (model_data_.empty()) {
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
@ -78,7 +78,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
}
void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBlingFireSentenceBreaker(api, info);
};

Просмотреть файл

@ -16,7 +16,7 @@ extern "C" int FreeModel(void* ModelPtr);
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
struct KernelBlingFireSentenceBreaker : BaseKernel {
KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info);
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
using ModelPtr = std::shared_ptr<void>;
@ -26,7 +26,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
};
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -463,7 +463,7 @@ bool IsEmptyUString(const ustring& str) {
KernelBpeTokenizer::KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info)
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
if (vocab.empty()) {
@ -582,7 +582,7 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
void* CustomOpBpeTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelBpeTokenizer(api, info);
}

Просмотреть файл

@ -6,7 +6,7 @@
class VocabData;
struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -18,7 +18,7 @@ struct KernelBpeTokenizer : BaseKernel {
};
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -7,7 +7,7 @@
#include "string_tensor.h"
#include "base64.h"
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes;
@ -100,7 +100,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
}
void* CustomOpSentencepieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelSentencepieceTokenizer(api, info);
};

Просмотреть файл

@ -8,7 +8,7 @@
#include "sentencepiece_processor.h"
struct KernelSentencepieceTokenizer : BaseKernel {
KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
};
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "wordpiece_tokenizer.hpp"
#include "nlohmann/json.hpp"
KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
@ -162,7 +162,7 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
ptr_row_lengths[i] = row_begins[i];
}
void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelWordpieceTokenizer(api, info);
};

Просмотреть файл

@ -11,7 +11,7 @@
#include "string_tensor.h"
struct KernelWordpieceTokenizer : BaseKernel {
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
@ -22,7 +22,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
};
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
}
static py::object BuildPyObjFromTensor(
OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
const shape_t& shape, ONNXTensorElementDataType dtype) {
std::vector<npy_intp> npy_dims;
for (auto n : shape) {
@ -210,7 +210,7 @@ typedef struct {
std::vector<int64_t> dimensions;
} InputInformation;
PyCustomOpKernel::PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info,
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
uint64_t id, const std::vector<std::string>& attrs)
: api_(api),
ort_(api_),
@ -377,7 +377,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
// No need to protect against concurrent access, GIL is doing that.
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
const auto [it_domain_op, success] = PyOp_container().insert(val);
assert(success || it_domain_op->second.length() > 0);
assert(success || !it_domain_op->second.empty());
it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
}

Просмотреть файл

@ -37,11 +37,11 @@ struct PyCustomOpDef {
};
struct PyCustomOpKernel {
PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
void Compute(OrtKernelContext* context);
private:
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
const OrtApi& api_;
Ort::CustomOpApi ort_;
uint64_t obj_id_;
std::map<std::string, std::string> attrs_values_;
@ -60,7 +60,7 @@ struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel
op_type_ = op;
}
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
};

Просмотреть файл

@ -24,7 +24,7 @@ const char* GetLibraryPath() {
}
struct KernelOne : BaseKernel {
KernelOne(OrtApi api) : BaseKernel(api) {
KernelOne(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
@ -52,7 +52,7 @@ struct KernelOne : BaseKernel {
};
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelOne(api);
};
const char* GetName() const {
@ -73,7 +73,7 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
};
struct KernelTwo : BaseKernel {
KernelTwo(OrtApi api) : BaseKernel(api) {
KernelTwo(const OrtApi& api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
@ -98,7 +98,7 @@ struct KernelTwo : BaseKernel {
};
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelTwo(api);
};
const char* GetName() const {
@ -119,7 +119,7 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
};
struct KernelThree : BaseKernel {
KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
if (!TryToGetAttribute("substr", substr_)) {
substr_ = "";
}
@ -145,7 +145,7 @@ struct KernelThree : BaseKernel {
};
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelThree(api, info);
};
const char* GetName() const {