Parent: abb0cf726f
Commit: a11c8128b2
@@ -52,8 +52,9 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
   size_t size = 0;
   OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
 
-  // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
-  if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
+  // The status should be a nullptr when querying for the size.
+  if (status != nullptr) {
+    api_.ReleaseStatus(status);
     return false;
   }
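Note on the hunk above: the diff cuts off before the rest of TryToGetAttribute, so the following is only an illustrative sketch of the usual two-call pattern around KernelInfoGetAttribute_string, assuming (as the new comment states) that the size-only query returns a nullptr status on success and that a second call then fills the buffer. The helper name and the exact buffer handling are assumptions for illustration, not the committed code.

#include <string>
#include "onnxruntime_c_api.h"

// Illustrative sketch only; not the committed implementation.
bool TryToGetStringAttributeSketch(const OrtApi& api, const OrtKernelInfo* info,
                                   const char* name, std::string& value) {
  size_t size = 0;
  // First call: query the required buffer size. On success the returned status is nullptr.
  OrtStatus* status = api.KernelInfoGetAttribute_string(info, name, nullptr, &size);
  if (status != nullptr) {
    api.ReleaseStatus(status);  // attribute missing or not a string
    return false;
  }
  value.resize(size);
  // Second call: copy the attribute into the buffer (size includes the terminating '\0').
  status = api.KernelInfoGetAttribute_string(info, name, &value[0], &size);
  if (status != nullptr) {
    api.ReleaseStatus(status);
    return false;
  }
  value.resize(size - 1);  // drop the trailing '\0'
  return true;
}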
Binary file not shown.
Binary file not shown.
@@ -118,6 +118,53 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
   };
 };
 
+struct KernelThree : BaseKernel {
+  KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
+    if (!TryToGetAttribute("substr", substr_)) {
+      substr_ = "";
+    }
+  }
+  void Compute(OrtKernelContext* context) {
+    // Setup inputs
+    const OrtValue* input_val = ort_.KernelContext_GetInput(context, 0);
+    std::vector<std::string> input_strs;
+    GetTensorMutableDataString(api_, ort_, context, input_val, input_strs);
+
+    // Setup output
+    OrtTensorDimensions dimensions(ort_, input_val);
+    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
+    int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
+
+    // Record substring locations in output
+    for (int64_t i = 0; i < dimensions.Size(); i++) {
+      out[i] = input_strs[i].find(substr_);
+    }
+  }
+ private:
+  std::string substr_;
+};
+
+struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
+  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
+    return new KernelThree(api, info);
+  };
+  const char* GetName() const {
+    return "CustomOpThree";
+  };
+  size_t GetInputTypeCount() const {
+    return 1;
+  };
+  ONNXTensorElementDataType GetInputType(size_t index) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  };
+  size_t GetOutputTypeCount() const {
+    return 1;
+  };
+  ONNXTensorElementDataType GetOutputType(size_t index) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  };
+};
+
 template <typename T>
 void _emplace_back(Ort::MemoryInfo& memory_info, std::vector<Ort::Value>& ort_inputs, const std::vector<T>& values, const std::vector<int64_t>& dims) {
   ort_inputs.emplace_back(Ort::Value::CreateTensor<T>(
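One detail worth calling out in KernelThree::Compute above: std::string::find returns std::string::npos when the substring is absent, and assigning that size_t value to the int64_t output wraps it to -1 on 64-bit platforms, which is exactly the sentinel the test below relies on (expected outputs {9, 0, -1}). A small standalone check of that conversion, independent of ONNX Runtime:

#include <cstdint>
#include <iostream>
#include <string>

int main() {
  // find() returns std::string::npos when "abc" does not occur; assigning the
  // result to int64_t wraps SIZE_MAX around to -1 on 64-bit platforms.
  int64_t miss = std::string("not found here").find("abc");
  int64_t hit = std::string("look for abc").find("abc");
  std::cout << miss << "\n";  // prints -1
  std::cout << hit << "\n";   // prints 9
  return 0;
}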
@@ -299,6 +346,42 @@ TEST(utils, test_ort_case) {
   TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
 }
 
+static CustomOpThree op_3rd;
+
+TEST(utils, test_get_str_attr) {
+  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
+
+  // Input: list of strings
+  std::vector<TestValue> inputs(1);
+  inputs[0].name = "input_1";
+  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  inputs[0].dims = {3};
+  inputs[0].values_string = {"look for abc", "abc is first", "not found here"};
+
+  // Expected output: location of the substring "abc" in each input string.
+  std::vector<TestValue> outputs(1);
+  outputs[0].name = "output_1";
+  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs[0].dims = {3};
+  outputs[0].values_int64 = {9, 0, -1};
+
+  std::filesystem::path model_path = "data";
+  model_path /= "custom_op_str_attr_test.onnx";
+  AddExternalCustomOp(&op_3rd);
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
+
+  // Expected output when the attribute is missing from the node.
+  std::vector<TestValue> outputs_missing(1);
+  outputs_missing[0].name = "output_1";
+  outputs_missing[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  outputs_missing[0].dims = {3};
+  outputs_missing[0].values_int64 = {0, 0, 0};
+
+  std::filesystem::path model_missing_attr_path = "data";
+  model_missing_attr_path /= "custom_op_str_attr_missing_test.onnx";
+  TestInference(*ort_env, model_missing_attr_path.c_str(), inputs, outputs_missing, GetLibraryPath());
+}
+
 TEST(ustring, tensor_operator) {
   OrtValue *tensor;
   OrtAllocator* allocator;
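For completeness, the test above registers CustomOpThree through the test harness's AddExternalCustomOp helper and runs it via TestInference. In a standalone application the same op would more typically be attached to a session through a custom-op domain; the sketch below uses the public Ort C++ API, with the domain name taken as a placeholder rather than a value from this commit, and the model path shown only as an example.

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "custom_op_demo");

  // Register CustomOpThree (defined in the diff above) under a custom domain
  // before creating the session. The domain name must match the one used by
  // the node in the model; "test.customop" here is a placeholder.
  static CustomOpThree custom_op_three;
  Ort::CustomOpDomain domain("test.customop");
  domain.Add(&custom_op_three);

  Ort::SessionOptions session_options;
  session_options.Add(domain);

  // Example model path; the model must contain a CustomOpThree node,
  // optionally carrying the "substr" string attribute.
  Ort::Session session(env, ORT_TSTR("data/custom_op_str_attr_test.onnx"), session_options);
  return 0;
}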