Issue #279: Fix bug that causes TryGetAttribute<std::string> to fail for valid inputs (#287)

This commit is contained in:
Adrian Lizarraga 2022-09-08 17:14:11 -07:00 коммит произвёл GitHub
Родитель abb0cf726f
Коммит a11c8128b2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 86 добавлений и 2 удалений

Просмотреть файл

@ -52,8 +52,9 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
size_t size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
// The status should be a nullptr when querying for the size.
if (status != nullptr) {
api_.ReleaseStatus(status);
return false;
}

Двоичные данные
test/data/custom_op_str_attr_missing_test.onnx Normal file

Двоичный файл не отображается.

Двоичные данные
test/data/custom_op_str_attr_test.onnx Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -118,6 +118,53 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
};
};
struct KernelThree : BaseKernel {
KernelThree(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
if (!TryToGetAttribute("substr", substr_)) {
substr_ = "";
}
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_val = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_strs;
GetTensorMutableDataString(api_, ort_, context, input_val, input_strs);
// Setup output
OrtTensorDimensions dimensions(ort_, input_val);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
// Record substring locations in output
for (int64_t i = 0; i < dimensions.Size(); i++) {
out[i] = input_strs[i].find(substr_);
}
}
private:
std::string substr_;
};
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
return new KernelThree(api, info);
};
const char* GetName() const {
return "CustomOpThree";
};
size_t GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType GetInputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType GetOutputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
};
template <typename T>
void _emplace_back(Ort::MemoryInfo& memory_info, std::vector<Ort::Value>& ort_inputs, const std::vector<T>& values, const std::vector<int64_t>& dims) {
ort_inputs.emplace_back(Ort::Value::CreateTensor<T>(
@ -299,6 +346,42 @@ TEST(utils, test_ort_case) {
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
static CustomOpThree op_3rd;
TEST(utils, test_get_str_attr) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
// Input: list of strings
std::vector<TestValue> inputs(1);
inputs[0].name = "input_1";
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
inputs[0].dims = {3};
inputs[0].values_string = {"look for abc", "abc is first", "not found here"};
// Expected output: location of the substring "abc" in each input string.
std::vector<TestValue> outputs(1);
outputs[0].name = "output_1";
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[0].dims = {3};
outputs[0].values_int64 = {9, 0, -1};
std::filesystem::path model_path = "data";
model_path /= "custom_op_str_attr_test.onnx";
AddExternalCustomOp(&op_3rd);
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
// Expected output when the attribute is missing from the node.
std::vector<TestValue> outputs_missing(1);
outputs_missing[0].name = "output_1";
outputs_missing[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs_missing[0].dims = {3};
outputs_missing[0].values_int64 = {0, 0, 0};
std::filesystem::path model_missing_attr_path = "data";
model_missing_attr_path /= "custom_op_str_attr_missing_test.onnx";
TestInference(*ort_env, model_missing_attr_path.c_str(), inputs, outputs_missing, GetLibraryPath());
}
TEST(ustring, tensor_operator) {
OrtValue *tensor;
OrtAllocator* allocator;