- Error codes are added to catch compilation errors and signal that a recompile is needed.
- Remote tensors are added to enable direct memory access for NPU inferencing.
- UMD bypass caching, enabled with OpenVINO 2024.4, eliminates the need for disk caching.

### Motivation and Context
- These changes are needed to ensure backward compatibility.
- UMD bypass caching eliminates driver-level caching.
- Remote tensors improve inference performance on the NPU.
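
As an illustration of the remote-tensor path, the sketch below shows how an application can request NPU-accessible I/O buffers through the `OpenVINO_RT_NPU` allocator introduced here. The calls mirror the onnxruntime_perf_test changes in this PR; the wrapper function, shape, and element type are only illustrative.

```cpp
// Minimal sketch: back an ORT output tensor with NPU device-accessible memory
// via the new "OpenVINO_RT_NPU" allocator. Assumes `session` was created with
// the OpenVINO EP targeting device_type NPU; shape/type are illustrative.
#include <cstdint>
#include <vector>
#include <onnxruntime_cxx_api.h>

void RunWithNpuBackedOutput(Ort::Session& session) {
  Ort::MemoryInfo mem_info("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput);
  Ort::Allocator npu_allocator(session, mem_info);  // must outlive the tensors it allocates

  std::vector<int64_t> shape{1, 3, 224, 224};
  Ort::Value output = Ort::Value::CreateTensor(npu_allocator, shape.data(), shape.size(),
                                               ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
  // `output` is backed by a level-zero host tensor created by OVRTAllocator, so the
  // EP can bind it for NPU inference without an extra copy. Pass it to session.Run()
  // as a pre-allocated output, as onnxruntime_perf_test does with use_device_mem.
}
```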

---------

Co-authored-by: Preetha Veeramalai <preetha.veeramalai@intel.com>
Co-authored-by: Srirammaswamy <srirammaswamy.s@intel.com>
Co-authored-by: saurabh <saurabh1.kale@intel.com>
Co-authored-by: Javier E. Martinez <javier.e.martinez@intel.com>
Co-authored-by: Eric Crawford <eric.r.crawford@intel.com>
Co-authored-by: jatinwadhwa921 <jatin.wadhwa@intel.com>
sfatimar authored 2024-09-12 03:25:40 +05:30; committed by GitHub
Parent b800328628
Commit 0309c5f02f
GPG key ID: B5690EEEBB952194
13 changed files with 338 additions and 43 deletions


@@ -21,6 +21,10 @@
message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release")
endif()
if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4)
add_definitions(-DUSE_OVEP_NPU_MEMORY=1)
endif()
if (WIN32)
unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO)
endif()


@@ -50,6 +50,8 @@ constexpr const char* HIP = "Hip";
constexpr const char* HIP_PINNED = "HipPinned";
constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
constexpr const char* OpenVINO_RT = "OpenVINO_RT";
constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU";
constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
constexpr size_t kAllocAlignment = 256;


@@ -145,6 +145,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),


@@ -5,6 +5,7 @@
#include <algorithm>
#include <cassert>
#include <fstream>
#include <regex>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
@@ -107,12 +108,15 @@ BackendManager::BackendManager(const GlobalContext& global_context,
subgraph_context_,
ep_ctx_handle_);
} catch (const OnnxRuntimeException& ex) {
std::string exception_str = ex.what();
bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos &&
!GetGlobalContext().disable_cpu_fallback &&
!ep_ctx_handle_.IsValidOVEPCtxGraph();
#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
ORT_THROW(ex.what());
eligible_for_cpu_fallback = false;
#else
if (device_type.find("NPU") != std::string::npos &&
!GetGlobalContext().disable_cpu_fallback) {
LOGS_DEFAULT(WARNING) << ex.what();
if (eligible_for_cpu_fallback) {
LOGS_DEFAULT(VERBOSE) << exception_str;
LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
<< "Falling back to OV CPU for execution";
GetGlobalContext().device_type = "CPU";
@@ -125,10 +129,32 @@ BackendManager::BackendManager(const GlobalContext& global_context,
} catch (std::string const& msg) {
ORT_THROW(msg);
}
} else {
ORT_THROW(ex.what());
}
#endif
if (!eligible_for_cpu_fallback) {
if (device_type.find("NPU") != std::string::npos &&
exception_str.find("intel_npu") != std::string::npos) {
// Handle NPU device related errors
#ifndef NDEBUG
ORT_THROW(exception_str + "\nModel needs to be recompiled\n");
#else
std::string error_message = "UNKNOWN NPU ERROR";
std::string error_code = "code 0x0";
std::regex error_message_pattern(R"(\bZE_\w*\b)");
std::regex error_code_pattern("code 0x[0-9a-fA-F]+");
std::smatch matches;
if (std::regex_search(exception_str, matches, error_message_pattern)) {
error_message = matches[0];
}
if (std::regex_search(exception_str, matches, error_code_pattern)) {
error_code = matches[0];
}
throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n");
#endif
} else {
ORT_THROW(exception_str);
}
}
}
}
if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) {
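
The release-build branch above reduces a driver exception to a compact, user-facing message. A standalone sketch of that extraction is shown below; the sample exception text is made up for illustration.

```cpp
// Standalone sketch of the ZE_* error parsing added in BackendManager.
// The sample exception string is hypothetical.
#include <iostream>
#include <regex>
#include <string>

int main() {
  std::string exception_str =
      "Exception from intel_npu: ZE_RESULT_ERROR_DEVICE_LOST, code 0x70000001";

  std::string error_message = "UNKNOWN NPU ERROR";
  std::string error_code = "code 0x0";
  std::regex error_message_pattern(R"(\bZE_\w*\b)");        // Level Zero error name
  std::regex error_code_pattern("code 0x[0-9a-fA-F]+");      // hex result code
  std::smatch matches;
  if (std::regex_search(exception_str, matches, error_message_pattern)) {
    error_message = matches[0];
  }
  if (std::regex_search(exception_str, matches, error_code_pattern)) {
    error_code = matches[0];
  }
  // Prints: ZE_RESULT_ERROR_DEVICE_LOST, code 0x70000001
  std::cout << error_message << ", " << error_code << "\nModel needs to be recompiled\n";
}
```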


@@ -48,14 +48,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
// Set the inference_num_threads property of the CPU
SetNumThreads(device_config);
#ifndef NDEBUG
if (IsDebugEnabled()) {
std::string file_name = subgraph_context.subgraph_name + "_static.onnx";
std::fstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(outfile);
}
#endif
try {
std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str;
@@ -180,6 +172,11 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type);
}
device_config.emplace(ov::device::properties("NPU", device_property));
#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3)
if (global_context_.export_ep_ctx_blob) {
global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true));
}
#endif
}
}
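
For reference, outside the EP the same setting corresponds to the direct OpenVINO call sketched below; it assumes OpenVINO 2024.4 or newer and an available NPU device.

```cpp
// Sketch of the UMD-bypass setting: when the application persists its own
// compiled blob (EPContext export), the driver's UMD model cache can be skipped.
// Requires OpenVINO >= 2024.4 and an NPU device.
#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_npu/properties.hpp>

int main() {
  ov::Core core;
  core.set_property("NPU", ov::intel_npu::bypass_umd_caching(true));
  return 0;
}
```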
@@ -295,16 +292,104 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
ORT_THROW(msg);
}
} else {
OVTensorPtr graph_input_blob;
try {
graph_input_blob = infer_request->GetTensor(input_name);
} catch (const char* msg) {
ORT_THROW(msg);
if ((global_context_.device_type.find("CPU") != std::string::npos ||
global_context_.device_type.find("GPU") != std::string::npos)) {
OVTensorPtr graph_input_blob;
try {
graph_input_blob = infer_request->GetTensor(input_name);
} catch (const char* msg) {
ORT_THROW(msg);
}
FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_);
} else {
auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name));
auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName();
ov_tensor_data_t ov_tensor_key;
ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name};
if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) {
ov_tensor_key = it->second;
} else {
// Does this make sense for both types of allocators?
auto input = graph_input_info.at(input_idx);
if (allocator_name == OpenVINO_RT_NPU) {
ov_tensor_key.copy_needed = false;
ov_tensor_key.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
(void*)tensor.GetTensorRawData());
} else {
ov_tensor_key.copy_needed = true;
ov_tensor_key.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape());
}
ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_key);
if (ov_tensor_key.copy_needed) {
const char* ort_tensor_data = tensor.GetTensorData<char>();
size_t tensor_data_size = ov_tensor_key.tensor_ptr->get_byte_size();
auto ort_batch_memory_offset = ort_tensor_data + tensor_data_size * batch_slice_idx;
std::memcpy(ov_tensor_key.tensor_ptr->data(), ort_batch_memory_offset, tensor_data_size);
}
try {
infer_request->SetTensor(input_name, ov_tensor_key.tensor_ptr);
} catch (const char* msg) {
ORT_THROW(msg);
}
}
}
FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_);
}
input_idx++;
}
if (global_context_.device_type.find("NPU") != std::string::npos) {
// Set the output blob as remote blob
auto graph_output_info = exe_network_.Get().outputs();
auto output_idx = 0;
for (auto output_info_iter = graph_output_info.begin();
output_info_iter != graph_output_info.end(); ++output_info_iter) {
auto output_names = output_info_iter->get_names();
std::string onnx_output_name;
std::string output_name;
// using the output name retrieved from ONNX original to match with the output names returned by OV tensors
for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) {
onnx_output_name = it->first;
if (output_names.find(onnx_output_name) != output_names.end()) {
// Assigning the output_name
output_name = it->first;
break;
}
}
size_t batch_size = 1;
Ort::UnownedValue tensor = GetOutputTensor(context,
batch_size,
infer_request,
output_name,
subgraph_context_.output_names);
auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName();
ov_tensor_data_t ov_tensor_data;
ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name};
if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) {
ov_tensor_data = it->second;
} else {
auto output = graph_output_info.at(output_idx);
if (allocator_name == OpenVINO_RT_NPU) {
ov_tensor_data.copy_needed = false;
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
(void*)tensor.GetTensorRawData());
} else {
ov_tensor_data.copy_needed = true;
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape());
}
ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_data);
try {
infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr);
} catch (const char* msg) {
ORT_THROW(msg);
}
}
output_idx++;
}
}
// Start Async inference
infer_request->StartAsync();
} catch (const char* msg) {
@@ -454,20 +539,42 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
" doesn't exist in the "
"list of OpenVINO output tensor names");
}
try {
graph_output_blob = infer_request->GetTensor(output_name);
} catch (const char* msg) {
ORT_THROW(msg);
}
size_t batch_size = 1;
Ort::UnownedValue output_tensor =
GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names);
auto mem_info = output_tensor.GetTensorMemoryInfo();
if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
return;
if ((global_context_.device_type.find("CPU") != std::string::npos ||
global_context_.device_type.find("GPU") != std::string::npos)) {
try {
graph_output_blob = infer_request->GetTensor(output_name);
} catch (const char* msg) {
ORT_THROW(msg);
}
size_t batch_size = 1;
Ort::UnownedValue output_tensor =
GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names);
auto mem_info = output_tensor.GetTensorMemoryInfo();
if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
return;
} else {
size_t batch_slice = 0;
FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice);
}
} else {
size_t batch_slice = 0;
FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice);
size_t batch_size = 1;
Ort::UnownedValue output_tensor =
GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names);
auto allocator_name = output_tensor.GetTensorMemoryInfo().GetAllocatorName();
ov_tensor_data_t ov_tensor_data;
ort_tensor_key_t ort_tensor_key{output_tensor.GetTensorRawData(), allocator_name};
if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) {
ov_tensor_data = it->second;
} else {
ORT_THROW(log_tag + "Expected all outputs to have associated OV::Tensor's");
}
if (ov_tensor_data.copy_needed) {
auto ort_tensor_data = output_tensor.GetTensorMutableData<char>();
size_t tensor_data_size = ov_tensor_data.tensor_ptr->get_byte_size();
auto ort_batch_memory_offset = ort_tensor_data /*+ tensor_data_size * batch_size*/;
std::memcpy(ort_batch_memory_offset, ov_tensor_data.tensor_ptr->data(), tensor_data_size);
}
}
}


@@ -11,6 +11,7 @@
#include <string>
#include <condition_variable>
#include <mutex>
#include <map>
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/openvino/contexts.h"
@@ -20,6 +21,11 @@
namespace onnxruntime {
namespace openvino_ep {
struct ov_tensor_data_t {
OVTensorPtr tensor_ptr;
bool copy_needed;
};
class InferRequestsQueue;
class BasicBackend : public IBackend {
public:
@@ -60,6 +66,9 @@ class BasicBackend : public IBackend {
#if defined IO_BUFFER_ENABLED
OVRemoteContextPtr remote_context_;
#endif
using ort_tensor_key_t = std::pair<const void*, const std::string>;
std::map<ort_tensor_key_t, ov_tensor_data_t> ort_ov_tensor_map;
};
class InferRequestsQueue {


@@ -10,6 +10,9 @@
#include "core/providers/openvino/onnx_ctx_model_helper.h"
#include "core/providers/openvino/ov_versions/capability.h"
#include "openvino/core/version.hpp"
#ifdef USE_OVEP_NPU_MEMORY
#include "core/providers/openvino/ov_allocator.h"
#endif
#define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz))
@@ -180,4 +183,18 @@ common::Status OpenVINOExecutionProvider::Compile(
return Status::OK();
}
#ifdef USE_OVEP_NPU_MEMORY
std::vector<AllocatorPtr> OpenVINOExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo npu_allocator_info{
[this](OrtDevice::DeviceId device_id) {
return std::make_unique<OVRTAllocator>(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU);
},
0,
};
// fill in allocator
return std::vector<AllocatorPtr>{CreateAllocator(npu_allocator_info)};
}
#endif
} // namespace onnxruntime


@@ -189,7 +189,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider {
const void* GetExecutionHandle() const noexcept override {
return nullptr;
}
#ifdef USE_OVEP_NPU_MEMORY
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
#endif
private:
std::unique_ptr<openvino_ep::GlobalContext> global_context_;
openvino_ep::EPCtxHandler ep_ctx_handle_{};


@@ -0,0 +1,55 @@
// Copyright (C) Intel Corporation
// Licensed under the MIT License
#ifdef USE_OVEP_NPU_MEMORY
#include "core/providers/openvino/ov_allocator.h"
#include "core/providers/openvino/ov_interface.h"
#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp"
#include "openvino/runtime/intel_npu/properties.hpp"
namespace onnxruntime {
using namespace openvino_ep;
constexpr size_t default_alignment = 4096;
static inline size_t align_up(size_t size, size_t pow2_alignment) {
return (size + pow2_alignment - 1) & ~(pow2_alignment - 1);
}
OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) {
if (device_type == OrtDevice::NPU) {
remote_ctx_ = core_.get_default_context("NPU").as<ov::intel_npu::level_zero::ZeroContext>();
} else {
ORT_THROW("Invalid device type");
}
}
void* OVRTAllocator::Alloc(size_t size) {
try {
size_t alloc_size = align_up(size + sizeof(ov::Tensor*) + default_alignment, default_alignment);
ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8,
{alloc_size}));
uintptr_t data_ptr = reinterpret_cast<uintptr_t>(tensor->data());
ov::Tensor** ptr = reinterpret_cast<ov::Tensor**>(align_up(data_ptr + sizeof(ov::Tensor*), default_alignment));
ptr[-1] = tensor;
return reinterpret_cast<void*>(ptr);
} catch (const ov::Exception& e) {
ORT_THROW(std::string("Alloc failed: ") + e.what());
}
return nullptr;
}
void OVRTAllocator::Free(void* p) {
try {
ov::Tensor** ptr = reinterpret_cast<ov::Tensor**>(p);
delete ptr[-1];
} catch (const ov::Exception& e) {
ORT_THROW(std::string("Free failed: ") + e.what());
}
}
} // namespace onnxruntime
#endif
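
`Alloc` above over-allocates, aligns the returned pointer to 4096 bytes, and stashes the owning `ov::Tensor*` in the slot just before it so `Free` can recover and delete it. A minimal standalone sketch of the same layout trick, with `std::malloc` standing in for the level-zero host tensor:

```cpp
// Illustrative sketch of the pointer-stash scheme used by OVRTAllocator:
// over-allocate, align the user pointer, and store the owning object's pointer
// just before it so Free() can recover it. std::malloc is only a stand-in.
#include <cstdint>
#include <cstdlib>

constexpr size_t default_alignment = 4096;

static inline size_t align_up(size_t size, size_t pow2_alignment) {
  return (size + pow2_alignment - 1) & ~(pow2_alignment - 1);
}

void* SketchAlloc(size_t size) {
  size_t alloc_size = align_up(size + sizeof(void*) + default_alignment, default_alignment);
  void* backing = std::malloc(alloc_size);  // stand-in for the remote host tensor
  if (backing == nullptr) return nullptr;
  uintptr_t data_ptr = reinterpret_cast<uintptr_t>(backing);
  // First aligned address that leaves room for one pointer in front of it.
  void** user_ptr = reinterpret_cast<void**>(align_up(data_ptr + sizeof(void*), default_alignment));
  user_ptr[-1] = backing;  // stash the owner for SketchFree()
  return user_ptr;
}

void SketchFree(void* p) {
  void** user_ptr = reinterpret_cast<void**>(p);
  std::free(user_ptr[-1]);  // recover and release the owner
}
```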


@@ -0,0 +1,24 @@
// Copyright (C) Intel Corporation
// Licensed under the MIT License
#ifdef USE_OVEP_NPU_MEMORY
#pragma once
#include "core/common/inlined_containers.h"
#include "core/framework/allocator.h"
#include "openvino/runtime/remote_context.hpp"
namespace onnxruntime {
class OVRTAllocator : public IAllocator {
public:
OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name);
void* Alloc(size_t size) override;
void Free(void* p) override;
private:
ov::Core& core_;
ov::RemoteContext remote_ctx_;
};
} // namespace onnxruntime
#endif


@@ -10,6 +10,7 @@
#include <utility>
#include "openvino/openvino.hpp"
#include "openvino/runtime/intel_npu/properties.hpp"
#include "openvino/pass/convert_fp32_to_fp16.hpp"
#include "openvino/frontend/manager.hpp"


@@ -34,10 +34,18 @@ std::chrono::duration<double> OnnxRuntimeTestSession::Run() {
// Randomly pick one OrtValueArray from test_inputs_. (NOT ThreadSafe)
const std::uniform_int_distribution<int>::param_type p(0, static_cast<int>(test_inputs_.size() - 1));
const size_t id = static_cast<size_t>(dist_(rand_engine_, p));
auto& input = test_inputs_.at(id);
auto start = std::chrono::high_resolution_clock::now();
auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(),
output_names_raw_ptr.data(), output_names_raw_ptr.size());
if (!use_device_mem) {
auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(),
output_names_raw_ptr.data(), output_names_raw_ptr.size());
} else {
session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(),
output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size());
}
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration_seconds = end - start;
return duration_seconds;
@@ -815,6 +823,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
"[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' "
"should be a boolean i.e. true or false. Default value is false.\n");
}
} else if (key == "use_device_mem") {
if (value == "true" || value == "True") {
use_device_mem = true;
}
} else {
ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n");
}
@@ -858,6 +870,27 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
input_names_str_[i] = m.GetInputName(i);
input_names_[i] = input_names_str_[i].c_str();
}
if (use_device_mem) {
Ort::MemoryInfo memory_info = Ort::MemoryInfo("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput);
custom_allocator_ = std::make_unique<Ort::Allocator>(session_, memory_info);
for (size_t i = 0; i < output_names_raw_ptr.size(); i++) {
Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> output_shape = tensor_info.GetShape();
// free dimensions are treated as 1 if not overridden
for (int64_t& dim : output_shape) {
if (dim == -1) {
dim = 1;
}
}
outputs_.push_back(Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)output_shape.data(),
output_shape.size(), tensor_info.GetElementType()));
}
}
}
template <typename T>
@@ -944,9 +977,11 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) {
// iterate over all input nodes
for (size_t i = 0; i < static_cast<size_t>(input_length_); i++) {
Ort::TypeInfo type_info = session_.GetInputTypeInfo(i);
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) {
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
if (!use_device_mem) {
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
}
std::vector<int64_t> input_node_dim = tensor_info.GetShape();
// free dimensions are treated as 1 if not overridden
@@ -955,12 +990,18 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) {
dim = 1;
}
}
auto allocator = Ort::AllocatorWithDefaultOptions();
Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(),
input_node_dim.size(), tensor_info.GetElementType());
InitializeTensorWithSeed(seed, input_tensor);
PreLoadTestData(0, i, std::move(input_tensor));
if (use_device_mem) {
Ort::Value input_tensor = Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)input_node_dim.data(),
input_node_dim.size(), tensor_info.GetElementType());
InitializeTensorWithSeed(seed, input_tensor);
PreLoadTestData(0, i, std::move(input_tensor));
} else {
auto allocator = Ort::AllocatorWithDefaultOptions();
Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(),
input_node_dim.size(), tensor_info.GetElementType());
InitializeTensorWithSeed(seed, input_tensor);
PreLoadTestData(0, i, std::move(input_tensor));
}
}
}
return true;


@@ -38,6 +38,8 @@ class OnnxRuntimeTestSession : public TestSession {
std::mt19937 rand_engine_;
std::uniform_int_distribution<int> dist_;
std::vector<std::vector<Ort::Value>> test_inputs_;
std::unique_ptr<Ort::Allocator> custom_allocator_;
std::vector<Ort::Value> outputs_;
std::vector<std::string> output_names_;
// The same size with output_names_.
// TODO: implement a customized allocator, then we can remove output_names_ to simplify this code
@@ -46,6 +48,7 @@ class OnnxRuntimeTestSession : public TestSession {
std::vector<std::string> input_names_str_;
const int input_length_;
std::string provider_name_;
bool use_device_mem = false;
};
} // namespace perftest