From 35408d94b6c1f6c720ad84e6f412fbcec039605d Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Fri, 14 Jul 2023 10:46:35 -0700 Subject: [PATCH] fix the ort version assigment bug (#490) --- includes/custom_op_lite.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index a08e896d..d001ee87 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -43,7 +43,7 @@ class TensorBase { std::string Shape2Str() const { if (shape_.has_value()) { std::string shape_str; - for (const auto& dim: *shape_) { + for (const auto& dim : *shape_) { shape_str.append(std::to_string(dim)); shape_str.append(", "); } @@ -372,10 +372,10 @@ struct Variadic : public TensorBase { break; } tensors_.emplace_back(tensor.release()); - } // for + } // for } } - template + template T* AllocateOutput(size_t ith_output, const std::vector& shape) { auto tensor = std::make_unique>(api_, ctx_, ith_output, false); auto raw_output = tensor.get()->Allocate(shape); @@ -402,6 +402,7 @@ struct Variadic : public TensorBase { const TensorPtr& operator[](size_t indice) const { return tensors_.at(indice); } + private: TensorPtrs tensors_; }; @@ -723,7 +724,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider) : op_name_(op_name), execution_provider_(execution_provider) { - OrtCustomOp::version = GetActiveOrtAPIVersion(); + int act_ver = GetActiveOrtAPIVersion(); + OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };