fix the ort version assigment bug (#490)

This commit is contained in:
Wenbing Li 2023-07-14 10:46:35 -07:00 коммит произвёл GitHub
Родитель bab1989644
Коммит 35408d94b6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 6 добавлений и 4 удалений

Просмотреть файл

@ -43,7 +43,7 @@ class TensorBase {
std::string Shape2Str() const {
if (shape_.has_value()) {
std::string shape_str;
for (const auto& dim: *shape_) {
for (const auto& dim : *shape_) {
shape_str.append(std::to_string(dim));
shape_str.append(", ");
}
@ -372,10 +372,10 @@ struct Variadic : public TensorBase {
break;
}
tensors_.emplace_back(tensor.release());
} // for
} // for
}
}
template<typename T>
template <typename T>
T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
auto raw_output = tensor.get()->Allocate(shape);
@ -402,6 +402,7 @@ struct Variadic : public TensorBase {
const TensorPtr& operator[](size_t indice) const {
return tensors_.at(indice);
}
private:
TensorPtrs tensors_;
};
@ -723,7 +724,8 @@ struct OrtLiteCustomOp : public OrtCustomOp {
OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
OrtCustomOp::version = GetActiveOrtAPIVersion();
int act_ver = GetActiveOrtAPIVersion();
OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };