Reduce default logger usage (#23030)
### Description <!-- Describe your changes. --> We have use cases where multiple sessions are created concurrently. Minimizing the usage of the default logger is important for these scenarios. Wire through the session logger to as many places as possible. The EP logger can also be used once the session is created (can't be used during EP construction/kernel registration but can be used in GetCapability and Compile). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Improve logging when there are concurrent sessions.
This commit is contained in:
Родитель
e12421be30
Коммит
708ee8556e
|
@ -8,6 +8,9 @@
|
|||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
|
||||
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
|
||||
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
|
||||
|
@ -33,6 +36,7 @@ class KernelRegistry {
|
|||
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
|
||||
Status TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
// map of type constraint name to required type
|
||||
|
@ -42,6 +46,7 @@ class KernelRegistry {
|
|||
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
|
||||
Status TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
/**
|
||||
|
@ -61,13 +66,15 @@ class KernelRegistry {
|
|||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
|
||||
ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver) {
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger) {
|
||||
const KernelCreateInfo* info;
|
||||
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
|
||||
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
|
||||
return st.IsOK();
|
||||
}
|
||||
|
||||
|
@ -83,6 +90,7 @@ class KernelRegistry {
|
|||
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
// Check whether the types of inputs/outputs of the given node match the extra
|
||||
|
|
|
@ -53,6 +53,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
TransformerLevel level,
|
||||
const SessionOptions& session_options,
|
||||
const IExecutionProvider& execution_provider /*required by constant folding*/,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
|
||||
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
|
||||
|
@ -84,6 +85,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
const SessionOptions& session_options,
|
||||
const SatApplyContextVariant& apply_context,
|
||||
const IExecutionProvider& cpu_execution_provider,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
|
||||
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
|
||||
|
|
|
@ -138,7 +138,8 @@ class PlannerImpl {
|
|||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
|
||||
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan,
|
||||
const logging::Logger& logger)
|
||||
: context_(&context),
|
||||
plan_(plan),
|
||||
parent_node_(parent_node),
|
||||
|
@ -148,14 +149,15 @@ class PlannerImpl {
|
|||
kernel_create_info_map_(kernel_create_info_map),
|
||||
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
|
||||
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map) {}
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map),
|
||||
logger_(logger) {
|
||||
}
|
||||
|
||||
Status CreatePlan(
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
const IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
#endif
|
||||
const PathString& partition_config_file,
|
||||
const logging::Logger& logger);
|
||||
const PathString& partition_config_file);
|
||||
|
||||
private:
|
||||
gsl::not_null<const ISequentialPlannerContext*> context_;
|
||||
|
@ -183,6 +185,12 @@ class PlannerImpl {
|
|||
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
|
||||
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;
|
||||
|
||||
// logger_ is not currently used in a minimal build
|
||||
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
[[maybe_unused]]
|
||||
#endif
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
|
||||
struct OrtValueInfo {
|
||||
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
|
||||
|
@ -213,6 +221,7 @@ class PlannerImpl {
|
|||
FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point)
|
||||
: ml_value(ort_value), deallocate_point(dealloc_point) {}
|
||||
};
|
||||
|
||||
// freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when
|
||||
// they became free (more recently freed earlier in the list).
|
||||
std::list<FreeBufferInfo> freelist_;
|
||||
|
@ -225,7 +234,8 @@ class PlannerImpl {
|
|||
}
|
||||
|
||||
int& UseCount(OrtValueIndex n) {
|
||||
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
|
||||
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
|
||||
"invalid value index: ", n, " against size ", ort_value_info_.size());
|
||||
return ort_value_info_[n].usecount;
|
||||
}
|
||||
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
|
||||
|
@ -335,9 +345,9 @@ class PlannerImpl {
|
|||
// we cannot.
|
||||
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
*reusable_input = Index(p_input_arg->Name());
|
||||
|
@ -361,9 +371,9 @@ class PlannerImpl {
|
|||
// we cannot.
|
||||
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
*reusable_input = Index(p_input_arg->Name());
|
||||
|
@ -397,8 +407,8 @@ class PlannerImpl {
|
|||
}
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -448,8 +458,8 @@ class PlannerImpl {
|
|||
return true;
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs.";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs.";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1198,9 +1208,9 @@ class PlannerImpl {
|
|||
// Otherwise, we cannot reuse the buffer.
|
||||
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1241,9 +1251,9 @@ class PlannerImpl {
|
|||
// Otherwise, we cannot reuse the buffer.
|
||||
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1290,8 +1300,8 @@ class PlannerImpl {
|
|||
}
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1869,8 +1879,7 @@ class PlannerImpl {
|
|||
}
|
||||
|
||||
#ifndef ORT_ENABLE_STREAM
|
||||
void PartitionIntoStreams(const logging::Logger& /*logger*/,
|
||||
const ExecutionProviders& /*execution_providers*/,
|
||||
void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/,
|
||||
const PathString& /*partition_config_file*/) {
|
||||
if (graph_viewer_.NumberOfNodes() > 0) {
|
||||
stream_nodes_.push_back({});
|
||||
|
@ -1915,11 +1924,11 @@ class PlannerImpl {
|
|||
|
||||
#else
|
||||
|
||||
void
|
||||
PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
|
||||
const PathString& partition_config_file) {
|
||||
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
|
||||
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
|
||||
void PartitionIntoStreams(const ExecutionProviders& execution_providers,
|
||||
const PathString& partition_config_file) {
|
||||
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file);
|
||||
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_,
|
||||
context_->GetExecutionOrder());
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
plan_.node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
|
||||
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
|
||||
|
@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan(
|
|||
#ifdef ORT_ENABLE_STREAM
|
||||
const IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
#endif
|
||||
const PathString& partition_config_file,
|
||||
const logging::Logger& logger) {
|
||||
const PathString& partition_config_file) {
|
||||
// 1. partition graph into streams
|
||||
PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file);
|
||||
PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file);
|
||||
|
||||
// 2. initialize the plan based on stream partition result
|
||||
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
|
||||
|
@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan(
|
|||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
|
||||
kernel_create_info_map, subgraphs_kernel_create_info_maps,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map, context, *plan);
|
||||
ort_value_name_idx_map, context, *plan, logger);
|
||||
|
||||
return planner.CreatePlan(
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
stream_handle_registry,
|
||||
#endif
|
||||
partition_config_file,
|
||||
logger);
|
||||
partition_config_file);
|
||||
}
|
||||
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
|
|
|
@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) {
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) {
|
||||
// automatic conversion from const std::vector&
|
||||
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
|
||||
InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());
|
||||
|
@ -83,7 +84,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
|
|||
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
|
||||
for (auto& consumer_node : consumer_nodes) {
|
||||
candidates.push(consumer_node->Index());
|
||||
LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
|
||||
LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -159,9 +160,9 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
|
|||
|
||||
if (place_in_cpu) {
|
||||
cpu_nodes.insert(cur);
|
||||
LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
|
||||
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
|
||||
<< " capable of executing this node";
|
||||
LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
|
||||
<< " because the CPU execution path is deemed faster than overhead involved with execution "
|
||||
"on other EPs capable of executing this node";
|
||||
for (auto* output : node->OutputDefs()) {
|
||||
cpu_output_args.insert(output);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
|
||||
/**
|
||||
Returns a list of nodes that are preferred on CPU.
|
||||
|
@ -19,6 +22,7 @@ namespace onnxruntime {
|
|||
*/
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep,
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
||||
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) {
|
||||
auto& current_ep = params.current_ep.get();
|
||||
const auto& ep_type = current_ep.Type();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) {
|
||||
LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
|
||||
LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
|
||||
"the layout transformer.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
|||
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
|
||||
const KernelLookup kernel_lookup{ep_type,
|
||||
kernel_registries_for_ep,
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver()};
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver(),
|
||||
logger};
|
||||
|
||||
auto& graph = params.graph.get();
|
||||
auto& capabilities = params.capabilities.get();
|
||||
|
@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
|||
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
|
||||
const KernelRegistryManager& kernel_registry_mgr,
|
||||
const IExecutionProvider& current_ep,
|
||||
const logging::Logger& logger,
|
||||
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
|
||||
const auto& ep_type = current_ep.Type();
|
||||
|
||||
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
|
||||
const KernelLookup kernel_lookup{ep_type,
|
||||
kernel_registries_for_ep,
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver()};
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver(),
|
||||
logger};
|
||||
|
||||
// TODO: Provide EP with a capability to look inside the functions.
|
||||
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
|
||||
|
@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id,
|
||||
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
|
||||
const layout_transformation::DebugGraphFn& debug_graph_fn) {
|
||||
const layout_transformation::DebugGraphFn& debug_graph_fn,
|
||||
const logging::Logger& logger) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
if (graph.NumberOfNodes() == 0) {
|
||||
|
@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
// we pass through the FuncManager from the top level graph
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
|
||||
transform_layout_fn, debug_graph_fn));
|
||||
transform_layout_fn, debug_graph_fn, logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
std::cref(transform_layout_fn),
|
||||
std::cref(debug_graph_fn)};
|
||||
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
|
||||
if (n != nullptr) {
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
|
||||
nodes_to_compile.push_back(n);
|
||||
capabilities_to_compile.push_back(std::move(capability));
|
||||
} else {
|
||||
|
@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
|
|||
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
|
||||
const KernelRegistryManager& kernel_registry_mgr,
|
||||
Graph& graph,
|
||||
const logging::Logger& logger,
|
||||
InlinedHashSet<std::string>& not_inlined,
|
||||
size_t& inlined_count) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
|
@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
|||
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
|
||||
kernel_registry_mgr,
|
||||
*subgraph,
|
||||
logger,
|
||||
not_inlined,
|
||||
inlined_count));
|
||||
}
|
||||
|
@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
|||
InlinedHashSet<NodeIndex> claimed_by_ep;
|
||||
for (const auto& ep : execution_providers) {
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger,
|
||||
capabilities));
|
||||
for (auto& capability : capabilities) {
|
||||
const auto& nodes = capability->sub_graph->nodes;
|
||||
if (nodes.size() == 1) {
|
||||
|
@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
|
|||
|
||||
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
|
||||
const ExecutionProviders& execution_providers,
|
||||
KernelRegistryManager& kernel_registry_manager) {
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const logging::Logger& logger) {
|
||||
bool modified_graph = false;
|
||||
|
||||
auto& graph = partition_params.graph.get();
|
||||
|
@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
|
|||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager,
|
||||
fused_kernel_registry, *ep, mode, fused_node_unique_id,
|
||||
transform_layout_function,
|
||||
partition_params.debug_graph_fn));
|
||||
partition_params.debug_graph_fn,
|
||||
logger));
|
||||
}
|
||||
|
||||
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
|
||||
|
@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
|
|||
|
||||
static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params,
|
||||
KernelRegistryManager& kernel_registry_mgr,
|
||||
IExecutionProvider& current_ep) {
|
||||
IExecutionProvider& current_ep,
|
||||
const logging::Logger& logger) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
auto& graph = partition_params.graph.get();
|
||||
|
@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
auto& subgraph = *entry.second;
|
||||
PartitionParams subgraph_partition_params = partition_params;
|
||||
subgraph_partition_params.graph = std::ref(subgraph);
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep,
|
||||
logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
};
|
||||
// clang-format on
|
||||
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
// Simplified partitioning where custom EPs may produce compiled nodes.
|
||||
static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
|
||||
const ExecutionProviders& execution_providers,
|
||||
KernelRegistryManager& kernel_registry_manager) {
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const logging::Logger& logger) {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : execution_providers) {
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
|
|||
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
|
||||
kernel_registry_manager,
|
||||
graph,
|
||||
logger,
|
||||
not_inlined,
|
||||
inlined_count));
|
||||
|
||||
|
@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
|||
|
||||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
|
||||
providers_, kernel_registry_mgr_));
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
|
||||
|
||||
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
|
||||
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
|
@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
|
||||
#endif //! defined(ORT_MINIMAL_BUILD)
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params,
|
||||
providers_, kernel_registry_mgr_));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger));
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
|
|||
public:
|
||||
KernelLookup(ProviderType provider_type,
|
||||
gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver)
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger)
|
||||
: provider_type_{provider_type},
|
||||
kernel_registries_{kernel_registries},
|
||||
kernel_type_str_resolver_{kernel_type_str_resolver} {
|
||||
kernel_type_str_resolver_{kernel_type_str_resolver},
|
||||
logger_{logger} {
|
||||
ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified.");
|
||||
}
|
||||
|
||||
const KernelCreateInfo* LookUpKernel(const Node& node) const override {
|
||||
const KernelCreateInfo* kernel_create_info{};
|
||||
for (const auto& registry : kernel_registries_) {
|
||||
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_,
|
||||
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_,
|
||||
&kernel_create_info);
|
||||
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
|
||||
return kernel_create_info;
|
||||
|
@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
|
|||
ProviderType provider_type_;
|
||||
const gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries_;
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver_;
|
||||
const logging::Logger& logger_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
const auto& node_provider = node.GetExecutionProviderType();
|
||||
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
|
||||
|
@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
std::ostream_iterator<std::string>(oss, "\n"));
|
||||
oss << ")";
|
||||
|
||||
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
|
||||
}
|
||||
|
||||
|
@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
|
||||
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out);
|
||||
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out);
|
||||
}
|
||||
|
||||
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out);
|
||||
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out);
|
||||
}
|
||||
|
||||
static bool KernelDefCompatible(int version, const KernelDef& kernel_def,
|
||||
|
@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
|
|||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider));
|
||||
if (out) *out = nullptr;
|
||||
|
@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
|
|||
std::ostream_iterator<std::string>(oss, "\n"));
|
||||
oss << ")";
|
||||
|
||||
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistr
|
|||
}
|
||||
#endif
|
||||
|
||||
Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
||||
Status KernelRegistryManager::SearchKernelRegistry(const Node& node, const logging::Logger& logger,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const {
|
||||
Status status;
|
||||
|
||||
|
@ -82,7 +82,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
}
|
||||
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
|
||||
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
}
|
||||
|
||||
if (p != nullptr) {
|
||||
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
|
||||
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
|
||||
}
|
||||
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r,
|
||||
const Node& node,
|
||||
const std::string& provider_type,
|
||||
const logging::Logger& logger) {
|
||||
const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
|
||||
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
|
||||
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver());
|
||||
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(),
|
||||
logger);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -67,13 +67,14 @@ class KernelRegistryManager {
|
|||
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status SearchKernelRegistry(const Node& node,
|
||||
Status SearchKernelRegistry(const Node& node, const logging::Logger& logger,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
||||
/**
|
||||
* Whether this node can be run on this provider
|
||||
*/
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status CreateKernel(const Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
|
|
|
@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
|
|||
bool saving_ort_format) {
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
|
||||
if (!status.IsOK() && saving_ort_format) {
|
||||
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
|
||||
// in that case we assigned the node to that EP but do not compile it into a fused node.
|
||||
|
@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
|
|||
// at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
|
||||
// if that's not possible for some reason we can fallback to the CPU EP implementation.
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
|
|
@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
// Create execution frame for executing constant nodes.
|
||||
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
|
||||
is_sparse_initializer_check);
|
||||
is_sparse_initializer_check, logger);
|
||||
#else
|
||||
// Create execution frame for executing constant nodes.
|
||||
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
|
||||
[](std::string const&) { return false; });
|
||||
OptimizerExecutionFrame::Info info(
|
||||
{node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; },
|
||||
logger);
|
||||
#endif
|
||||
|
||||
std::vector<int> fetch_mlvalue_idxs;
|
||||
|
|
|
@ -190,6 +190,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
TransformerLevel level,
|
||||
const SessionOptions& session_options,
|
||||
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
|
||||
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
|
||||
|
@ -404,7 +405,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
}
|
||||
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
|
||||
logger);
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
|
@ -437,6 +439,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
const SessionOptions& session_options,
|
||||
const SatApplyContextVariant& apply_context,
|
||||
const IExecutionProvider& cpu_execution_provider,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
|
||||
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
|
||||
|
@ -490,7 +493,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
#ifndef DISABLE_CONTRIB_OPS
|
||||
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
|
||||
logger);
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
|
|
|
@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
|
|||
// going to a node that will need a Cast.
|
||||
//
|
||||
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph,
|
||||
const KernelRegistry& cpu_kernel_registry,
|
||||
const logging::Logger& logger) {
|
||||
// we can check if it's an isolated fp16 node
|
||||
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
|
||||
// does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node),
|
||||
|
@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto lookup_status = cpu_kernel_registry.TryFindKernel(
|
||||
kCpuExecutionProvider, node.OpType(), node.Domain(),
|
||||
node.SinceVersion(), type_constraint_map, &kernel_create_info);
|
||||
node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
|
||||
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
return false;
|
||||
}
|
||||
|
||||
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry,
|
||||
const logging::Logger& logger) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) {
|
||||
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
|
||||
node.SetExecutionProviderType("");
|
||||
}
|
||||
|
@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
return dst_bit_length <= src_bit_length;
|
||||
}
|
||||
|
||||
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
|
||||
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") ||
|
||||
(*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
if (force_cpu_fp32_)
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_));
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger));
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -44,7 +44,9 @@ NhwcConvLookup(
|
|||
return &(iter->second);
|
||||
}
|
||||
|
||||
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
|
||||
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
|
||||
std::shared_ptr<KernelRegistry> cpu_kernel_registry,
|
||||
const logging::Logger& logger) noexcept
|
||||
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
|
||||
if (!cpu_kernel_registry) {
|
||||
// This is a CPU op nodes optimizer, not useful if cpu EP is not available.
|
||||
|
@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_,
|
||||
qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info);
|
||||
qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_,
|
||||
qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info);
|
||||
qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_,
|
||||
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_,
|
||||
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_,
|
||||
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_,
|
||||
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
|
|
@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed.
|
|||
class NhwcTransformer : public GraphTransformer {
|
||||
private:
|
||||
public:
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry,
|
||||
const logging::Logger& logger) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Usually called right after constructor, it shows whether
|
||||
|
|
|
@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
const InitializedTensorSet& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger)
|
||||
: execution_provider_(execution_provider),
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func) {
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func),
|
||||
logger_(logger) {
|
||||
allocator_ptr_ = std::make_shared<CPUAllocator>();
|
||||
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
|
||||
|
||||
|
@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
|
||||
const std::filesystem::path& /* model_path */,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger)
|
||||
: execution_provider_(execution_provider),
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func) {
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func),
|
||||
logger_(logger) {
|
||||
allocator_ptr_ = std::make_shared<CPUAllocator>();
|
||||
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
|
||||
|
||||
|
@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const {
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
|
||||
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
|
||||
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out);
|
||||
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out);
|
||||
}
|
||||
|
||||
static Status TryCreateKernel(const Node& node,
|
||||
|
@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node,
|
|||
FuncManager& funcs_mgr,
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const ConfigOptions& config_options,
|
||||
const logging::Logger& logger,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
|
||||
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger,
|
||||
&kernel_create_info));
|
||||
|
||||
static const AllocatorMap dummy_allocators;
|
||||
|
@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption
|
|||
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
|
||||
FuncManager func;
|
||||
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_,
|
||||
op_kernel);
|
||||
|
||||
// Kernel found in the CPU kernel registry
|
||||
|
|
|
@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
const InitializedTensorSet& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Info(const std::vector<const Node*>& nodes,
|
||||
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
~Info() = default;
|
||||
|
||||
|
@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
std::unique_ptr<NodeIndexInfo> node_index_info_;
|
||||
const IExecutionProvider& execution_provider_;
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info);
|
||||
};
|
||||
|
|
|
@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion(
|
|||
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
|
||||
}
|
||||
|
||||
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
||||
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) {
|
||||
constexpr size_t w_idx = 1;
|
||||
constexpr size_t w_zp_idx = 9;
|
||||
constexpr size_t r_idx = 2;
|
||||
|
@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) ||
|
||||
!graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) ||
|
||||
r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
<< " cannot locate recurrence tensor of const int8 type,"
|
||||
<< " int8 overflow might impact precision !";
|
||||
return false;
|
||||
|
@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) ||
|
||||
!graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) ||
|
||||
r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
<< " unable to locate recurrence tensor or its zero point value,"
|
||||
<< " int8 overflow might impact precision !";
|
||||
return false;
|
||||
|
@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int
|
|||
if (graph_utils::IsSupportedOptypeVersionAndDomain(
|
||||
op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) {
|
||||
// This one has two set of quantized arguments
|
||||
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph);
|
||||
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger);
|
||||
continue; // go on to next operator node
|
||||
}
|
||||
|
||||
|
|
|
@ -291,7 +291,8 @@ SelectorManager::SelectorManager() {
|
|||
InitializeSelectorsMap();
|
||||
}
|
||||
|
||||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
|
||||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer,
|
||||
const logging::Logger& logger) const {
|
||||
std::vector<NodeGroup> qdq_selections;
|
||||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
|
@ -313,7 +314,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
|
||||
if (!versions.empty()) {
|
||||
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType();
|
||||
LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -329,7 +330,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
}
|
||||
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) {
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
|
@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
|||
|
||||
// Get QDQ NodeUnits first
|
||||
QDQ::SelectorManager selector_mgr;
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger);
|
||||
|
||||
for (const auto& qdq_selection : qdq_selections) {
|
||||
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
class GraphViewer;
|
||||
class Node;
|
||||
|
||||
|
@ -65,7 +67,7 @@ class SelectorManager {
|
|||
|
||||
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
|
||||
// Can be used in QDQ support in different EPs
|
||||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;
|
||||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const;
|
||||
|
||||
private:
|
||||
Selectors qdq_selectors_;
|
||||
|
@ -88,7 +90,7 @@ class SelectorManager {
|
|||
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
|
||||
// library whereas it should be able to be used by an EP with no dependency on optimizers.
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer);
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -17,13 +17,22 @@ class TransformerMemcpyImpl {
|
|||
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
|
||||
: graph_(graph), provider_(provider) {}
|
||||
|
||||
bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);
|
||||
bool ModifyGraph(const KernelRegistryManager& schema_registries,
|
||||
const logging::Logger& logger,
|
||||
int& copy_node_counter);
|
||||
|
||||
private:
|
||||
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
|
||||
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
|
||||
void ProcessDefs(onnxruntime::Node& node,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger);
|
||||
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
const logging::Logger& logger);
|
||||
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
|
||||
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);
|
||||
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
|
||||
const InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger);
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl);
|
||||
|
@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
// find defs that require copy
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
// as we process the defs, collect all the initializers consumed at the current graph level
|
||||
ProcessDefs(node, kernel_registries, initializers_consumed);
|
||||
ProcessDefs(node, kernel_registries, initializers_consumed, logger);
|
||||
}
|
||||
|
||||
// for initializers shared by different providers, create dups
|
||||
if (ProcessInitializers(kernel_registries, initializers_consumed))
|
||||
if (ProcessInitializers(kernel_registries, initializers_consumed, logger))
|
||||
modified = true;
|
||||
|
||||
for (auto arg : graph_.GetInputs())
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : non_provider_input_defs_)
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : non_provider_output_defs_)
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : graph_.GetInputs())
|
||||
// For inputs we need to create a copy node only when the input is connected to both provider
|
||||
|
@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
return modified;
|
||||
}
|
||||
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed) {
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger) {
|
||||
auto node_provider_type = node.GetExecutionProviderType();
|
||||
if ((node_provider_type == provider_) ||
|
||||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
|
||||
|
@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
provider_nodes_.insert(&node);
|
||||
// note KernelCreateInfo might be nullptr for custom kernel
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci));
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci));
|
||||
|
||||
bool is_implicit_input = false;
|
||||
auto process_inputs =
|
||||
|
@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
}
|
||||
|
||||
// for non_provider defs, collect the nodes that expect it is provider tensor as input/output.
|
||||
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) {
|
||||
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
const logging::Logger& logger) {
|
||||
for (auto& it : graph_.Nodes()) {
|
||||
if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
|
||||
auto input_it =
|
||||
|
@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
|
|||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
|
||||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
|
||||
if (arg_input_index != -1) {
|
||||
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
|
||||
}
|
||||
|
@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co
|
|||
// We duplicate any initializer that is used by both provider nodes and non-provider nodes
|
||||
// to ensure that provider nodes and non-provider nodes don't share initializers, as they
|
||||
// need to stay in different memory locations.
|
||||
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) {
|
||||
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries,
|
||||
const InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger) {
|
||||
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacements;
|
||||
for (const auto& pair : initializers_consumed) {
|
||||
const auto& name = pair.first;
|
||||
|
@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
|
|||
auto dup_replacements = replacements;
|
||||
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci);
|
||||
auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
if (kci == nullptr) continue;
|
||||
if (kci->kernel_def == nullptr) continue;
|
||||
|
|
|
@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
|
|||
|
||||
const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node);
|
||||
if (cann_kernel_def == nullptr) {
|
||||
LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
|
||||
<< " node name: " << node.Name();
|
||||
LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
|
||||
<< " node name: " << node.Name();
|
||||
continue;
|
||||
}
|
||||
|
||||
candidates.push_back(node.Index());
|
||||
}
|
||||
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger());
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
continue;
|
||||
|
|
|
@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
// For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
|
||||
// These are usually shape related computation subgraphs
|
||||
// Following logic can be extended for other EPs
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
|
|
|
@ -62,7 +62,8 @@ namespace Dml
|
|||
const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
|
||||
const auto kernel_lookup = onnxruntime::KernelLookup{provider_type,
|
||||
gsl::make_span(®istry, 1),
|
||||
kernel_type_str_resolver};
|
||||
kernel_type_str_resolver,
|
||||
logger};
|
||||
|
||||
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
|
||||
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;
|
||||
|
|
|
@ -54,7 +54,8 @@ namespace Dml
|
|||
const auto kernelLookup = onnxruntime::KernelLookup(
|
||||
providerType,
|
||||
gsl::make_span(®istry, 1),
|
||||
kernelTypeStrResolver);
|
||||
kernelTypeStrResolver,
|
||||
logger);
|
||||
|
||||
onnxruntime::GraphViewer graphViewer(graph);
|
||||
const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -95,7 +95,7 @@ namespace Dml
|
|||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
|
||||
{
|
||||
#ifdef ENABLE_GRAPH_COMPILATION
|
||||
return m_impl->GetCapability(graph, kernel_lookup);
|
||||
return m_impl->GetCapability(graph, kernel_lookup, *GetLogger());
|
||||
#else
|
||||
return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup);
|
||||
#endif
|
||||
|
@ -876,7 +876,8 @@ namespace Dml
|
|||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
ExecutionProviderImpl::GetCapability(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
const onnxruntime::logging::Logger& logger) const
|
||||
{
|
||||
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
|
||||
|
@ -900,7 +901,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
// Get the list of nodes that should stay on the CPU
|
||||
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes);
|
||||
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger);
|
||||
|
||||
for (size_t nodeIndex : toplogicalOrder)
|
||||
{
|
||||
|
|
|
@ -88,7 +88,8 @@ namespace Dml
|
|||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
GetCapability(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
const onnxruntime::logging::Logger& logger
|
||||
) const;
|
||||
|
||||
uint32_t GetSupportedDeviceDataTypeMask() const;
|
||||
|
|
|
@ -818,7 +818,7 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
|
|||
candidates.push_back(node.Index());
|
||||
tenative_candidates.push_back(node.Index());
|
||||
}
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0) {
|
||||
|
|
|
@ -32,8 +32,16 @@ namespace nnapi {
|
|||
|
||||
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices,
|
||||
TargetDeviceOption target_device_option)
|
||||
: nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique<Model>(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) {
|
||||
TargetDeviceOption target_device_option,
|
||||
const logging::Logger& logger)
|
||||
: nnapi_(nnapi_handle),
|
||||
graph_viewer_(graph_viewer),
|
||||
nnapi_model_{std::make_unique<Model>(nnapi_handle)},
|
||||
shaper_{graph_viewer},
|
||||
nnapi_target_devices_(nnapi_target_devices),
|
||||
target_device_option_(target_device_option),
|
||||
nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)),
|
||||
logger_(logger) {
|
||||
nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_;
|
||||
}
|
||||
|
||||
|
@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
|
|||
}
|
||||
|
||||
void ModelBuilder::PreprocessNodeUnits() {
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
|
||||
}
|
||||
|
||||
// Help to get all quantized operators' input and the NodeUnit(s) using the input
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
|
||||
struct NnApi;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
class GraphViewer;
|
||||
enum class DataLayout;
|
||||
class NodeUnit;
|
||||
|
@ -31,7 +33,8 @@ class ModelBuilder {
|
|||
using Shape = Shaper::Shape;
|
||||
|
||||
ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option);
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option,
|
||||
const logging::Logger& logger);
|
||||
|
||||
common::Status Compile(std::unique_ptr<Model>& model);
|
||||
|
||||
|
@ -173,6 +176,9 @@ class ModelBuilder {
|
|||
// <1,1> <1,2> <1,3>
|
||||
InlinedVector<std::pair<size_t, int32_t>> operations_recorder_;
|
||||
#endif
|
||||
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// Convert the ONNX model to ANeuralNetworksModel
|
||||
common::Status Prepare();
|
||||
|
||||
|
|
|
@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {}
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const {
|
||||
const auto& logger = *GetLogger();
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
// TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators)
|
||||
|
@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL;
|
||||
#endif
|
||||
}();
|
||||
LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
|
||||
LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
|
||||
|
||||
const nnapi::OpSupportCheckParams params{
|
||||
android_feature_level,
|
||||
|
@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
};
|
||||
|
||||
if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) {
|
||||
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
|
||||
LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
|
||||
<< params.android_feature_level
|
||||
<< "] is lower than minimal supported NNAPI API feature level ["
|
||||
<< ORT_NNAPI_MIN_API_LEVEL
|
||||
|
@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit to be checked for multiple times
|
||||
|
@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
node_unit_supported_result[node_unit] = supported;
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported
|
||||
LOGS(logger, VERBOSE) << "Node supported: [" << supported
|
||||
<< "] Operator type: [" << node.OpType()
|
||||
<< "] index: [" << node.Index()
|
||||
<< "] name: [" << node.Name()
|
||||
|
@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
// If the graph is partitioned in multiple subgraphs, and this may impact performance,
|
||||
// we want to give users a summary message at warning level.
|
||||
if (num_of_partitions > 1) {
|
||||
LOGS_DEFAULT(WARNING) << summary_msg;
|
||||
LOGS(logger, WARNING) << summary_msg;
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << summary_msg;
|
||||
LOGS(logger, INFO) << summary_msg;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context,
|
|||
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
using namespace android::nn::wrapper;
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
|
||||
Node& fused_node = fused_node_and_graph.fused_node;
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
|
||||
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_);
|
||||
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger);
|
||||
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
|
||||
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
|
||||
|
||||
|
|
|
@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
|
|||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger);
|
||||
|
||||
std::unordered_set<const NodeUnit*> seen_node_units;
|
||||
const auto& node_indices = src_graph.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
|
|||
// valid throughout the lifetime of the ModelBuilder
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This name must be same with the EPContext node name
|
||||
const auto& graph_name = fused_node.Name();
|
||||
|
|
|
@ -718,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// remove is_qnn_ctx_model related code
|
||||
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map,
|
||||
|
|
|
@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
// For ROCM EP, exclude the subgraph that is preferred to be placed in CPU
|
||||
// These are usually shape related computation subgraphs
|
||||
// Following logic can be extended for other EPs
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
|
|
|
@ -294,7 +294,8 @@ std::unique_ptr<IDataTransfer> CreateGPUDataTransfer();
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger);
|
||||
|
||||
std::string GetEnvironmentVar(const std::string& var_name);
|
||||
|
||||
|
@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
|
|||
|
||||
namespace QDQ {
|
||||
inline std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer* graph_viewer) {
|
||||
return g_host->QDQ__GetAllNodeUnits(graph_viewer);
|
||||
GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) {
|
||||
return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger);
|
||||
}
|
||||
} // namespace QDQ
|
||||
|
||||
|
|
|
@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) {
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) {
|
||||
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) {
|
||||
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
}
|
||||
|
||||
namespace profiling {
|
||||
|
|
|
@ -202,7 +202,8 @@ struct ProviderHost {
|
|||
|
||||
virtual std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) = 0;
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) = 0;
|
||||
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
|
||||
|
@ -890,7 +891,7 @@ struct ProviderHost {
|
|||
virtual std::unique_ptr<Node__EdgeIterator> NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0;
|
||||
|
||||
virtual std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0;
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0;
|
||||
|
||||
// Model
|
||||
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
|
||||
|
|
|
@ -34,7 +34,8 @@ namespace onnxruntime {
|
|||
|
||||
namespace vsi {
|
||||
namespace npu {
|
||||
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) {
|
||||
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger)
|
||||
: graph_viewer_(graph_viewer), logger_(logger) {
|
||||
Prepare();
|
||||
context_ = tim::vx::Context::Create();
|
||||
graph_ = context_->CreateGraph();
|
||||
|
@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g
|
|||
}
|
||||
|
||||
bool GraphEP::Prepare() {
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
|
||||
for (const auto& node_unit : node_unit_holder_) {
|
||||
auto quant_op_type = util::GetQuantizedOpType(*node_unit);
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ struct NodeIOInfo {
|
|||
|
||||
class GraphEP {
|
||||
public:
|
||||
explicit GraphEP(const GraphViewer& graph_viewer);
|
||||
explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
~GraphEP() {}
|
||||
|
||||
bool Prepare();
|
||||
|
@ -104,6 +104,7 @@ class GraphEP {
|
|||
// In the form of {input_name, [NodeUnit(s) using the input]}
|
||||
std::unordered_map<std::string, std::vector<const NodeUnit*>> all_quantized_op_inputs_;
|
||||
const GraphViewer& graph_viewer_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
|
||||
// valid throughout the lifetime of the ModelBuilder
|
||||
|
|
|
@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {}
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const {
|
||||
const auto& logger = *GetLogger();
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
if (graph_viewer.IsSubgraph()) {
|
||||
|
@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit to be checked for multiple times
|
||||
|
@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
}
|
||||
|
||||
Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
||||
OrtKernelContext* context) {
|
||||
OrtKernelContext* context,
|
||||
const logging::Logger& logger) {
|
||||
Ort::KernelContext ctx(context);
|
||||
size_t num_in = ctx.GetInputCount();
|
||||
const size_t num_inputs = graph_ep->GetGraphInputs().size();
|
||||
|
@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
|||
}
|
||||
|
||||
if (!graph_ep->GetGraph()->Run()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to run graph.";
|
||||
LOGS(logger, ERROR) << "Failed to run graph.";
|
||||
}
|
||||
for (size_t i = 0; i < ctx.GetOutputCount(); i++) {
|
||||
auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor;
|
||||
|
@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
|||
|
||||
Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_viewer = fused_node_graph.filtered_graph;
|
||||
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer);
|
||||
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer, logger);
|
||||
|
||||
for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
|
||||
LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
|
||||
<< graph_viewer.IsInitializedTensor(tensor->Name());
|
||||
auto input = std::make_shared<vsi::npu::GraphIOInfo>();
|
||||
input->name = tensor->Name();
|
||||
|
@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
graph_ep->GetGraphInputs().push_back(input);
|
||||
}
|
||||
for (auto tensor : graph_viewer.GetOutputs()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
|
||||
LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
|
||||
auto output = std::make_shared<vsi::npu::GraphIOInfo>();
|
||||
output->name = tensor->Name();
|
||||
output->is_initializer = false;
|
||||
|
@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
if (node != &node_unit.GetNode()) {
|
||||
continue;
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]";
|
||||
LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]";
|
||||
vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit);
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(INFO) << "Verifying graph";
|
||||
LOGS(logger, INFO) << "Verifying graph";
|
||||
graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile();
|
||||
if (!graph_ep->GetCompiled()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to verify graph.";
|
||||
LOGS(logger, ERROR) << "Failed to verify graph.";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "Graph has been verified successfully.";
|
||||
LOGS(logger, INFO) << "Graph has been verified successfully.";
|
||||
}
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
|
@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
[graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */,
|
||||
OrtKernelContext* context) {
|
||||
std::lock_guard<std::mutex> lock(this->GetMutex());
|
||||
Status res = ComputeStateFunc(graph_ep.get(), context);
|
||||
Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger());
|
||||
return res;
|
||||
};
|
||||
|
||||
|
|
|
@ -798,7 +798,8 @@ std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapa
|
|||
candidates.push_back(node.Index());
|
||||
tenative_candidates.push_back(node.Index());
|
||||
}
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
|
||||
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0) {
|
||||
|
|
|
@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit(
|
|||
std::vector<std::unique_ptr<ComputeCapability>> XnnpackExecutionProvider::GetCapability(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const IKernelLookup& /*kernel_lookup*/) const {
|
||||
const auto& logger = *GetLogger();
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
|
||||
|
||||
std::shared_ptr<KernelRegistry> registry = GetKernelRegistry();
|
||||
|
@ -268,7 +269,7 @@ std::vector<std::unique_ptr<ComputeCapability>> XnnpackExecutionProvider::GetCap
|
|||
// Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph, logger);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit being checked for multiple times
|
||||
|
|
|
@ -1644,7 +1644,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations(
|
|||
level <= static_cast<int>(session_options.graph_optimization_level);
|
||||
++level) {
|
||||
const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild(
|
||||
static_cast<TransformerLevel>(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep,
|
||||
static_cast<TransformerLevel>(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger,
|
||||
optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors);
|
||||
|
||||
for (const auto& transformer : transformers) {
|
||||
|
@ -1840,7 +1840,8 @@ common::Status InferenceSession::Initialize() {
|
|||
ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_,
|
||||
session_options_.graph_optimization_level,
|
||||
minimal_build_optimization_handling,
|
||||
record_runtime_optimization_produced_op_schema));
|
||||
record_runtime_optimization_produced_op_schema,
|
||||
*session_logger_));
|
||||
|
||||
#ifdef USE_DML
|
||||
const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider);
|
||||
|
@ -2112,7 +2113,7 @@ common::Status InferenceSession::Initialize() {
|
|||
std::vector<TuningResults> tuning_results;
|
||||
bool found_tuning_results = false;
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata(
|
||||
model_metadata_, tuning_results, found_tuning_results));
|
||||
model_metadata_, tuning_results, found_tuning_results, *session_logger_));
|
||||
if (found_tuning_results) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/ false, /*auto_enable*/ true));
|
||||
}
|
||||
|
@ -3233,7 +3234,8 @@ common::Status InferenceSession::AddPredefinedTransformers(
|
|||
GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
MinimalBuildOptimizationHandling minimal_build_optimization_handling,
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const {
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
TransformerLevel level = static_cast<TransformerLevel>(i);
|
||||
|
@ -3245,7 +3247,7 @@ common::Status InferenceSession::AddPredefinedTransformers(
|
|||
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
|
||||
|
||||
if (use_full_build_optimizations) {
|
||||
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep,
|
||||
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger,
|
||||
optimizers_to_disable_,
|
||||
GetIntraOpThreadPoolToUse(),
|
||||
session_state_->GetMutableBufferedTensors());
|
||||
|
@ -3257,6 +3259,7 @@ common::Status InferenceSession::AddPredefinedTransformers(
|
|||
record_runtime_optimization_produced_op_schema_fn}}
|
||||
: SatApplyContextVariant{SatDirectApplicationContext{}};
|
||||
return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
|
||||
logger,
|
||||
optimizers_to_disable_,
|
||||
GetIntraOpThreadPoolToUse(),
|
||||
session_state_->GetMutableBufferedTensors());
|
||||
|
|
|
@ -690,8 +690,9 @@ class InferenceSession {
|
|||
* If we encounter an invalid request, we return an error
|
||||
* back to the user.
|
||||
*/
|
||||
[[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
|
||||
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const;
|
||||
[[nodiscard]] common::Status ValidateAndParseShrinkArenaString(
|
||||
const std::string& ort_device_list,
|
||||
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const;
|
||||
|
||||
/*
|
||||
* Performs the shrinkage of arenas requested to be shrunk by the user
|
||||
|
@ -708,7 +709,8 @@ class InferenceSession {
|
|||
GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
MinimalBuildOptimizationHandling minimal_build_optimization_handling,
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const;
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
|
||||
const logging::Logger& logger) const;
|
||||
|
||||
common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format);
|
||||
|
||||
|
|
|
@ -236,7 +236,8 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options
|
|||
|
||||
Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
|
||||
std::vector<TuningResults>& results,
|
||||
bool& key_found) {
|
||||
bool& key_found,
|
||||
const logging::Logger& logger) {
|
||||
results.clear();
|
||||
key_found = false;
|
||||
auto it = metadata.custom_metadata_map.find(kTuningResultsKeys);
|
||||
|
@ -245,7 +246,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met
|
|||
}
|
||||
|
||||
key_found = true;
|
||||
LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model";
|
||||
LOGS(logger, INFO) << "Found tuning results in the model file to be used while loading the model";
|
||||
|
||||
Status status;
|
||||
ORT_TRY {
|
||||
|
|
|
@ -19,7 +19,9 @@ using json = nlohmann::json;
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
namespace inference_session_utils {
|
||||
|
||||
// need this value to be accessible in all builds in order to report error for attempted usage in a minimal build
|
||||
|
@ -60,7 +62,8 @@ class JsonConfigParser {
|
|||
|
||||
Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
|
||||
/*out*/ std::vector<TuningResults>& results,
|
||||
/*out*/ bool& key_found);
|
||||
/*out*/ bool& key_found,
|
||||
const logging::Logger& logger);
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
|
|
|
@ -279,8 +279,9 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) override {
|
||||
return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) override {
|
||||
return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
}
|
||||
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
|
@ -1057,8 +1058,8 @@ struct ProviderHostImpl : ProviderHost {
|
|||
}
|
||||
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) override {
|
||||
return QDQ::GetAllNodeUnits(*graph_viewer);
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) override {
|
||||
return QDQ::GetAllNodeUnits(*graph_viewer, logger);
|
||||
}
|
||||
|
||||
// Model (wrapped)
|
||||
|
|
|
@ -314,7 +314,8 @@ class StandAloneKernelContext : public OpKernelContext {
|
|||
AllocatorPtr allocator_;
|
||||
}; // StandAloneKernelContext
|
||||
|
||||
onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, OrtOpAttr** op_attr) {
|
||||
onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type,
|
||||
OrtOpAttr** op_attr) {
|
||||
auto attr = std::make_unique<ONNX_NAMESPACE::AttributeProto>();
|
||||
onnxruntime::Status status = onnxruntime::Status::OK();
|
||||
attr->set_name(std::string{name});
|
||||
|
@ -410,7 +411,9 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info,
|
|||
|
||||
node_ptr->SetSinceVersion(version);
|
||||
|
||||
auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, &kernel_create_info);
|
||||
auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map,
|
||||
logging::LoggingManager::DefaultLogger(), // no other logger available
|
||||
&kernel_create_info);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
auto& kernel_def = kernel_create_info->kernel_def;
|
||||
|
|
|
@ -252,6 +252,7 @@ class PlannerTest : public ::testing::Test {
|
|||
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg,
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map) {
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
const IExecutionProvider* ep = execution_providers_.Get(*p_node);
|
||||
ASSERT_NE(ep, nullptr);
|
||||
auto info = std::make_unique<OpKernelInfo>(
|
||||
|
@ -261,7 +262,7 @@ class PlannerTest : public ::testing::Test {
|
|||
op_kernel_infos_.push_back(std::move(info));
|
||||
const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{};
|
||||
if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider,
|
||||
kernel_type_str_resolver)) {
|
||||
kernel_type_str_resolver, logger)) {
|
||||
ASSERT_STATUS_OK(reg->Register(
|
||||
KernelCreateInfo(std::make_unique<KernelDef>(kernel_def),
|
||||
[](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
|
||||
|
@ -271,7 +272,7 @@ class PlannerTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
const KernelCreateInfo* kci;
|
||||
ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci));
|
||||
ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, logger, &kci));
|
||||
kernel_create_info_map.insert({p_node->Index(), gsl::not_null<const KernelCreateInfo*>(kci)});
|
||||
}
|
||||
|
||||
|
@ -283,7 +284,8 @@ class PlannerTest : public ::testing::Test {
|
|||
}
|
||||
}
|
||||
|
||||
void CreatePlan(const std::vector<const NodeArg*>& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) {
|
||||
void CreatePlan(const std::vector<const NodeArg*>& outer_scope_node_args = {},
|
||||
bool invoke_createPlan_explicityly = true) {
|
||||
state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_));
|
||||
EXPECT_EQ(graph_.Resolve(), Status::OK());
|
||||
|
|
|
@ -831,7 +831,8 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map<s
|
|||
|
||||
bool has_constant_folding = false;
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, *e.get(), {});
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, *e.get(), logger,
|
||||
{});
|
||||
for (auto& transformer : transformers) {
|
||||
if (transformer->Name() == "ConstantFolding") {
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1));
|
||||
|
@ -4704,7 +4705,8 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
|
|||
// Compare results
|
||||
double per_sample_tolerance = 1e-3;
|
||||
double relative_per_sample_tolerance = 0.0;
|
||||
auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false);
|
||||
auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0],
|
||||
per_sample_tolerance, relative_per_sample_tolerance, false);
|
||||
EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
|
||||
}
|
||||
|
||||
|
@ -4713,7 +4715,8 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt
|
|||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
bool has_gelu_approximation = false;
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {});
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(),
|
||||
DefaultLoggingManager().DefaultLogger(), {});
|
||||
for (auto& transformer : transformers) {
|
||||
if (transformer->Name() == "GeluApproximation") {
|
||||
has_gelu_approximation = true;
|
||||
|
@ -4728,7 +4731,8 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) {
|
|||
auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) {
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_ep = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
bool has_double_qdq_remover = false;
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {});
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(),
|
||||
DefaultLoggingManager().DefaultLogger(), {});
|
||||
for (auto& transformer : transformers) {
|
||||
if (transformer->Name() == "DoubleQDQPairsRemover") {
|
||||
has_double_qdq_remover = true;
|
||||
|
|
|
@ -36,9 +36,11 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
std::string l2_transformer = "ConvActivationFusion";
|
||||
InlinedHashSet<std::string> disabled = {l1_rule1, l1_transformer, l2_transformer};
|
||||
CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{});
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep);
|
||||
auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled);
|
||||
auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger);
|
||||
auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger,
|
||||
disabled);
|
||||
|
||||
// check ConstantFolding transformer was removed
|
||||
ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
|
||||
|
@ -61,8 +63,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// check that ConvActivationFusion was removed
|
||||
all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep);
|
||||
filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled);
|
||||
all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger);
|
||||
filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger,
|
||||
disabled);
|
||||
ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ namespace test {
|
|||
TEST(OptimizerTest, Basic) {
|
||||
Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
constexpr int tensor_dim = 10;
|
||||
|
@ -66,22 +67,21 @@ TEST(OptimizerTest, Basic) {
|
|||
|
||||
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
|
||||
graph.ModelPath(),
|
||||
*cpu_execution_provider.get(),
|
||||
[&graph](const std::string& name) -> bool {
|
||||
return graph.IsSparseInitializer(name);
|
||||
});
|
||||
OptimizerExecutionFrame::Info info(
|
||||
nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(),
|
||||
[&graph](const std::string& name) -> bool {
|
||||
return graph.IsSparseInitializer(name);
|
||||
},
|
||||
logger);
|
||||
#else
|
||||
OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
|
||||
graph.ModelPath(),
|
||||
*cpu_execution_provider.get(),
|
||||
[](std::string const&) { return false; });
|
||||
OptimizerExecutionFrame::Info info(
|
||||
nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(),
|
||||
[](std::string const&) { return false; },
|
||||
logger);
|
||||
#endif //! defined(DISABLE_SPARSE_TENSORS)
|
||||
|
||||
std::vector<int> fetch_mlvalue_idxs{info.GetMLValueIndex("out")};
|
||||
OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
|
||||
const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
|
|
|
@ -3928,6 +3928,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) {
|
|||
|
||||
TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx");
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
SessionOptions so;
|
||||
// We want to keep the graph un-optimized to prevent QDQ transformer to kick in
|
||||
|
@ -3962,7 +3963,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
|
||||
// Check if SelectorManager get a conv qdq group selection as expected
|
||||
{
|
||||
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer);
|
||||
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger);
|
||||
ASSERT_FALSE(result.empty());
|
||||
const auto& qdq_group = result.at(0);
|
||||
ASSERT_EQ(std::vector<NodeIndex>({0, 1, 2}), qdq_group.dq_nodes);
|
||||
|
@ -3977,7 +3978,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger);
|
||||
|
||||
// We should get a single QDQ Node unit in the result
|
||||
ASSERT_EQ(1, node_unit_holder.size());
|
||||
|
@ -4045,7 +4046,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
|
||||
// Check SelectorManager will get empty result
|
||||
{
|
||||
const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer);
|
||||
const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer, logger);
|
||||
ASSERT_TRUE(result.empty());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -420,6 +420,7 @@ bool SetEpsForAllNodes(Graph& graph,
|
|||
continue;
|
||||
|
||||
bool found = false;
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
for (const auto& ep : execution_providers) {
|
||||
auto provider_type = ep->Type();
|
||||
|
@ -438,7 +439,8 @@ bool SetEpsForAllNodes(Graph& graph,
|
|||
}
|
||||
|
||||
// Check the EP has an impl for the node from builtin registry.
|
||||
if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver)) {
|
||||
if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver,
|
||||
logger)) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
|
@ -451,6 +453,7 @@ bool SetEpsForAllNodes(Graph& graph,
|
|||
std::string_view(kMSInternalNHWCDomain),
|
||||
node.SinceVersion(),
|
||||
type_constraint_map,
|
||||
logger,
|
||||
&kci);
|
||||
if (status.IsOK() && kci != nullptr) {
|
||||
found = true;
|
||||
|
@ -463,7 +466,7 @@ bool SetEpsForAllNodes(Graph& graph,
|
|||
std::any_of(custom_registries->cbegin(), custom_registries->cend(),
|
||||
[&](auto reg) {
|
||||
return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(),
|
||||
kernel_type_str_resolver);
|
||||
kernel_type_str_resolver, logger);
|
||||
})) {
|
||||
found = true;
|
||||
break;
|
||||
|
|
|
@ -42,8 +42,9 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
|
|||
}
|
||||
#endif
|
||||
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{domain_, opset_version_}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
{{domain_, opset_version_}}, {}, logger);
|
||||
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::unordered_map<std::string, OrtValue> initializer_map;
|
||||
|
@ -89,8 +90,7 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
|
|||
ASSERT_STATUS_OK(graph.Resolve());
|
||||
|
||||
node.SetExecutionProviderType(ep_type);
|
||||
OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type),
|
||||
[](std::string const&) { return false; });
|
||||
OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), [](std::string const&) { return false; }, logger);
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
ASSERT_STATUS_OK(info.TryFindKernel(&node, &kernel_create_info));
|
||||
ASSERT_TRUE(kernel_create_info);
|
||||
|
@ -139,7 +139,7 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
|
|||
#pragma warning(disable : 6387)
|
||||
#endif
|
||||
OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs, outputs);
|
||||
OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger);
|
||||
#ifdef _WIN32
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
|
|
@ -51,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) {
|
|||
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
|
||||
gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
|
||||
|
@ -82,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function<void(ModelTestBuilder&)>&
|
|||
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
const auto is_node_supported = [&](const Node& /*node*/) -> bool {
|
||||
return true;
|
||||
|
|
|
@ -758,7 +758,8 @@ Status TrainingSession::AddPredefinedTransformers(
|
|||
GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
MinimalBuildOptimizationHandling minimal_build_optimization_handling,
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const {
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/,
|
||||
const logging::Logger& /*logger*/) const {
|
||||
ORT_RETURN_IF_NOT(
|
||||
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations,
|
||||
"Only applying full build optimizations is supported by TrainingSession.");
|
||||
|
|
|
@ -489,7 +489,8 @@ class TrainingSession : public InferenceSession {
|
|||
GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
MinimalBuildOptimizationHandling minimal_build_optimization_handling,
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override;
|
||||
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
/** Perform auto-diff to add backward graph into the model.
|
||||
@param weights_to_train a set of weights to be training.
|
||||
|
|
|
@ -139,7 +139,8 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
|
|||
|
||||
auto reg = execution_provider->GetKernelRegistry();
|
||||
const KernelCreateInfo* kci;
|
||||
auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci);
|
||||
auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver,
|
||||
DefaultLoggingManager().DefaultLogger(), &kci);
|
||||
if (!st.IsOK()) {
|
||||
// The goal here is unclear. It seems best to leave it to the Session
|
||||
// creation to figure out whether the model can be executed using some
|
||||
|
|
|
@ -23,8 +23,10 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
|
|||
InlinedHashSet<std::string> disabled = {l1_rule1, l1_transformer, l2_transformer};
|
||||
CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{});
|
||||
|
||||
auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep);
|
||||
auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled);
|
||||
const auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger);
|
||||
auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger,
|
||||
disabled);
|
||||
|
||||
// check ConstantFolding transformer was removed
|
||||
ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
|
||||
|
@ -47,8 +49,8 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
|
|||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// check that ConvActivationFusion was removed
|
||||
all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep);
|
||||
filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled);
|
||||
all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger);
|
||||
filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, disabled);
|
||||
ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
|
||||
#endif
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче