Merge branch 'main' into Cjian/jdk17-js
# Conflicts: # js/react_native/android/gradle/wrapper/gradle-wrapper.properties
This commit is contained in:
Коммит
2c077e194f
|
@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
|
||||
${MLAS_SRC_DIR}/flashattn.cpp
|
||||
${MLAS_SRC_DIR}/cast.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding.cpp
|
||||
)
|
||||
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
|
@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
|
|||
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
${MLAS_SRC_DIR}/fp16_neon_common.cpp
|
||||
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
|
||||
)
|
||||
|
||||
set(mlas_platform_preprocess_srcs
|
||||
|
@ -367,6 +372,8 @@ else()
|
|||
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
|
||||
|
@ -384,8 +391,9 @@ else()
|
|||
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
|
||||
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/fp16_neon_common.cpp
|
||||
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
|
||||
|
@ -395,8 +403,9 @@ else()
|
|||
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
endif()
|
||||
|
||||
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
|
||||
|
|
|
@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>(Optional) Hardware architecture.</dd>
|
||||
<dt><tt>main_context</tt> : int</dt>
|
||||
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
|
||||
<dt><tt>max_size</tt> : int</dt>
|
||||
<dd>max size in the context. Usage depend on the EP.</dd>
|
||||
<dt><tt>notes</tt> : string</dt>
|
||||
<dd>(Optional) Some notes for the model</dd>
|
||||
<dt><tt>onnx_model_filename</tt> : string</dt>
|
||||
|
|
|
@ -8,6 +8,9 @@
|
|||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
|
||||
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
|
||||
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
|
||||
|
@ -33,6 +36,7 @@ class KernelRegistry {
|
|||
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
|
||||
Status TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
// map of type constraint name to required type
|
||||
|
@ -42,6 +46,7 @@ class KernelRegistry {
|
|||
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
|
||||
Status TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
/**
|
||||
|
@ -61,13 +66,15 @@ class KernelRegistry {
|
|||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
|
||||
ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver) {
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger) {
|
||||
const KernelCreateInfo* info;
|
||||
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
|
||||
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
|
||||
return st.IsOK();
|
||||
}
|
||||
|
||||
|
@ -83,6 +90,7 @@ class KernelRegistry {
|
|||
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
// Check whether the types of inputs/outputs of the given node match the extra
|
||||
|
|
|
@ -53,6 +53,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
TransformerLevel level,
|
||||
const SessionOptions& session_options,
|
||||
const IExecutionProvider& execution_provider /*required by constant folding*/,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
|
||||
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
|
||||
|
@ -84,6 +85,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
const SessionOptions& session_options,
|
||||
const SatApplyContextVariant& apply_context,
|
||||
const IExecutionProvider& cpu_execution_provider,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
|
||||
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
|
||||
|
|
|
@ -47,8 +47,20 @@ enum COREMLFlags {
|
|||
// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
|
||||
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
|
||||
static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
|
||||
// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES
|
||||
static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
|
||||
static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
|
||||
// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property
|
||||
// Core ML segments the model’s compute graph and specializes each segment for the target compute device.
|
||||
// This process can affect the model loading time and the prediction latency.
|
||||
// Use this option to tailor the specialization strategy for your model.
|
||||
static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy";
|
||||
// Profile the Core ML MLComputePlan.
|
||||
// This logs the hardware each operator is dispatched to and the estimated execution time.
|
||||
// Intended for developer usage but provide useful diagnostic information if performance is not as expected.
|
||||
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
|
||||
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
|
||||
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -3667,6 +3667,9 @@ struct OrtApi {
|
|||
* execution provider (typically CPU EP).
|
||||
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
|
||||
* - "1": Enabled.
|
||||
* "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
|
||||
* - "0": Default. Disabled.
|
||||
* - "1": Enabled.
|
||||
*
|
||||
* SNPE supported keys:
|
||||
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
|
||||
|
@ -4612,6 +4615,8 @@ struct OrtApi {
|
|||
* \param[in] num_keys
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.17.
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
|
||||
_In_ OrtSessionOptions* options,
|
||||
|
@ -4629,6 +4634,8 @@ struct OrtApi {
|
|||
* \param[in] num_keys
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.18.
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI,
|
||||
_In_ OrtSessionOptions* options,
|
||||
|
@ -4642,7 +4649,10 @@ struct OrtApi {
|
|||
* \param[in] mem_info OrtMemoryInfo instance
|
||||
* \param[in] count_or_bytes How many bytes is this scratch buffer
|
||||
* \param[out] out A pointer to the scrach buffer
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.18.
|
||||
*/
|
||||
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
|
||||
|
||||
|
@ -4653,6 +4663,8 @@ struct OrtApi {
|
|||
* \param[out] out A pointer to OrtAllocator
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.18.
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
|
||||
|
||||
|
@ -4674,6 +4686,8 @@ struct OrtApi {
|
|||
* \param[in] num_external_initializer_files Number of external files
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.18.
|
||||
*/
|
||||
ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options,
|
||||
_In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names,
|
||||
|
@ -4696,6 +4710,8 @@ struct OrtApi {
|
|||
* OrtApi::ReleaseLoraAdapter.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.20.
|
||||
*/
|
||||
ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtLoraAdapter** out);
|
||||
|
@ -4714,6 +4730,8 @@ struct OrtApi {
|
|||
* OrtApi::ReleaseLoraAdapter.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.20.
|
||||
*/
|
||||
ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtLoraAdapter** out);
|
||||
|
@ -4735,6 +4753,8 @@ struct OrtApi {
|
|||
* \param[in] adapter OrtLoraAdapter instance
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.20.
|
||||
*/
|
||||
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
|
||||
|
||||
|
@ -4753,6 +4773,8 @@ struct OrtApi {
|
|||
* \param[in] kv_len Number of elements in the keys and values arrays
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.20.
|
||||
*/
|
||||
ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
|
||||
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);
|
||||
|
|
|
@ -198,19 +198,6 @@ module.exports = {
|
|||
'_OrtReleaseTensor',
|
||||
'_OrtRun',
|
||||
'_OrtRunWithBinding',
|
||||
'_OrtTrainingCopyParametersFromBuffer',
|
||||
'_OrtTrainingCopyParametersToBuffer',
|
||||
'_OrtTrainingCreateSession',
|
||||
'_OrtTrainingEvalStep',
|
||||
'_OrtTrainingGetModelInputOutputCount',
|
||||
'_OrtTrainingGetModelInputOutputName',
|
||||
'_OrtTrainingGetParametersSize',
|
||||
'_OrtTrainingLazyResetGrad',
|
||||
'_OrtTrainingLoadCheckpoint',
|
||||
'_OrtTrainingOptimizerStep',
|
||||
'_OrtTrainingReleaseCheckpoint',
|
||||
'_OrtTrainingReleaseSession',
|
||||
'_OrtTrainingRunTrainStep',
|
||||
],
|
||||
},
|
||||
],
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import { InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { TrainingSession } from './training-session.js';
|
||||
|
||||
/**
|
||||
* @ignore
|
||||
|
@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
|
|||
): Promise<SessionHandler.ReturnType>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a handler instance of a training inference session.
|
||||
*
|
||||
* @ignore
|
||||
*/
|
||||
export interface TrainingSessionHandler extends SessionHandler {
|
||||
readonly evalInputNames: readonly string[];
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
lazyResetGrad(): Promise<void>;
|
||||
runTrainStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType>;
|
||||
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
|
||||
runEvalStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType>;
|
||||
|
||||
getParametersSize(trainableOnly: boolean): Promise<number>;
|
||||
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
|
||||
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a backend that provides implementation of model inferencing.
|
||||
*
|
||||
|
@ -84,14 +56,6 @@ export interface Backend {
|
|||
uriOrBuffer: string | Uint8Array,
|
||||
options?: InferenceSession.SessionOptions,
|
||||
): Promise<InferenceSessionHandler>;
|
||||
|
||||
createTrainingSessionHandler?(
|
||||
checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
options: InferenceSession.SessionOptions,
|
||||
): Promise<TrainingSessionHandler>;
|
||||
}
|
||||
|
||||
export { registerBackend } from './backend-impl.js';
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import { env as envImpl } from './env-impl.js';
|
||||
import { TryGetGlobalType } from './type-helper.js';
|
||||
|
||||
export declare namespace Env {
|
||||
export type WasmPathPrefix = string;
|
||||
|
@ -14,7 +15,6 @@ export declare namespace Env {
|
|||
* If not modified, the filename of the .wasm file is:
|
||||
* - `ort-wasm-simd-threaded.wasm` for default build
|
||||
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
|
||||
* - `ort-training-wasm-simd-threaded.wasm` for training build
|
||||
*/
|
||||
wasm?: URL | string;
|
||||
/**
|
||||
|
@ -25,7 +25,6 @@ export declare namespace Env {
|
|||
* If not modified, the filename of the .mjs file is:
|
||||
* - `ort-wasm-simd-threaded.mjs` for default build
|
||||
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
|
||||
* - `ort-training-wasm-simd-threaded.mjs` for training build
|
||||
*/
|
||||
mjs?: URL | string;
|
||||
}
|
||||
|
@ -200,22 +199,16 @@ export declare namespace Env {
|
|||
* value will be the GPU adapter that created by the underlying WebGPU backend.
|
||||
*
|
||||
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
|
||||
* Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
|
||||
*
|
||||
* see comments on {@link Tensor.GpuBufferType}
|
||||
*/
|
||||
adapter: unknown;
|
||||
adapter: TryGetGlobalType<'GPUAdapter'>;
|
||||
/**
|
||||
* Get the device for WebGPU.
|
||||
*
|
||||
* This property is only available after the first WebGPU inference session is created.
|
||||
*
|
||||
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
|
||||
* Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
|
||||
*
|
||||
* see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
|
||||
*/
|
||||
readonly device: unknown;
|
||||
readonly device: TryGetGlobalType<'GPUDevice'>;
|
||||
/**
|
||||
* Set or get whether validate input content.
|
||||
*
|
||||
|
|
|
@ -26,4 +26,3 @@ export * from './tensor-factory.js';
|
|||
export * from './trace.js';
|
||||
export * from './onnx-model.js';
|
||||
export * from './onnx-value.js';
|
||||
export * from './training-session.js';
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js';
|
||||
import { OnnxModelOptions } from './onnx-model.js';
|
||||
import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js';
|
||||
import { TryGetGlobalType } from './type-helper.js';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-redeclare */
|
||||
|
||||
|
@ -282,7 +283,7 @@ export declare namespace InferenceSession {
|
|||
extends WebNNExecutionProviderName,
|
||||
Omit<WebNNContextOptions, 'deviceType'>,
|
||||
Required<Pick<WebNNContextOptions, 'deviceType'>> {
|
||||
context: unknown /* MLContext */;
|
||||
context: TryGetGlobalType<'MLContext'>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -291,8 +292,8 @@ export declare namespace InferenceSession {
|
|||
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice
|
||||
*/
|
||||
export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName {
|
||||
context: unknown /* MLContext */;
|
||||
gpuDevice: unknown /* GPUDevice */;
|
||||
context: TryGetGlobalType<'MLContext'>;
|
||||
gpuDevice: TryGetGlobalType<'GPUDevice'>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
import { TensorFactory } from './tensor-factory.js';
|
||||
import { Tensor as TensorImpl } from './tensor-impl.js';
|
||||
import { TypedTensorUtils } from './tensor-utils.js';
|
||||
import { TryGetGlobalType } from './type-helper.js';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-redeclare */
|
||||
|
||||
|
@ -131,24 +132,19 @@ export declare namespace Tensor {
|
|||
*/
|
||||
export type TextureDataTypes = 'float32';
|
||||
|
||||
type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
|
||||
/**
|
||||
* type alias for WebGPU buffer
|
||||
*
|
||||
* The reason why we don't use type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is because "@webgpu/types"
|
||||
* requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its version need to be chosen
|
||||
* carefully according to the TypeScript version being used. This means so far there is not a way to keep every
|
||||
* TypeScript version happy. It turns out that we will easily broke users on some TypeScript version.
|
||||
*
|
||||
* for more info see https://github.com/gpuweb/types/issues/127
|
||||
*/
|
||||
export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
|
||||
export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>;
|
||||
|
||||
type MLTensorTypeFallback = { destroy(): void };
|
||||
/**
|
||||
* type alias for WebNN MLTensor
|
||||
*
|
||||
* The specification for WebNN's MLTensor is currently in flux.
|
||||
*/
|
||||
export type MLTensorType = unknown;
|
||||
export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>;
|
||||
|
||||
/**
|
||||
* supported data types for constructing a tensor from a WebGPU buffer
|
||||
|
|
|
@ -1,273 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { resolveBackendAndExecutionProviders } from './backend-impl.js';
|
||||
import { SessionHandler, TrainingSessionHandler } from './backend.js';
|
||||
import { InferenceSession as InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { Tensor } from './tensor.js';
|
||||
import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js';
|
||||
|
||||
type SessionOptions = InferenceSession.SessionOptions;
|
||||
type FeedsType = InferenceSession.FeedsType;
|
||||
type FetchesType = InferenceSession.FetchesType;
|
||||
type ReturnType = InferenceSession.ReturnType;
|
||||
type RunOptions = InferenceSession.RunOptions;
|
||||
|
||||
const noBackendErrMsg: string =
|
||||
'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files.";
|
||||
|
||||
export class TrainingSession implements TrainingSessionInterface {
|
||||
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
|
||||
this.handler = handler;
|
||||
this.hasOptimizerModel = hasOptimizerModel;
|
||||
this.hasEvalModel = hasEvalModel;
|
||||
}
|
||||
private handler: TrainingSessionHandler;
|
||||
private hasOptimizerModel: boolean;
|
||||
private hasEvalModel: boolean;
|
||||
|
||||
get trainingInputNames(): readonly string[] {
|
||||
return this.handler.inputNames;
|
||||
}
|
||||
get trainingOutputNames(): readonly string[] {
|
||||
return this.handler.outputNames;
|
||||
}
|
||||
|
||||
get evalInputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalInputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
get evalOutputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalOutputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
static async create(
|
||||
trainingOptions: TrainingSessionCreateOptions,
|
||||
sessionOptions?: SessionOptions,
|
||||
): Promise<TrainingSession> {
|
||||
const evalModel: string | Uint8Array = trainingOptions.evalModel || '';
|
||||
const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || '';
|
||||
const options: SessionOptions = sessionOptions || {};
|
||||
|
||||
// resolve backend, update session options with validated EPs, and create session handler
|
||||
const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
|
||||
if (backend.createTrainingSessionHandler) {
|
||||
const handler = await backend.createTrainingSessionHandler(
|
||||
trainingOptions.checkpointState,
|
||||
trainingOptions.trainModel,
|
||||
evalModel,
|
||||
optimizerModel,
|
||||
optionsWithValidatedEPs,
|
||||
);
|
||||
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
|
||||
} else {
|
||||
throw new Error(noBackendErrMsg);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
|
||||
* the given parameters to SessionHandler.FetchesType and RunOptions.
|
||||
*
|
||||
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
|
||||
* names.
|
||||
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
|
||||
* names.
|
||||
* @param feeds the required input
|
||||
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
|
||||
* @param arg2 optional RunOptions object.
|
||||
* @returns
|
||||
*/
|
||||
typeNarrowingForRunStep(
|
||||
inputNames: readonly string[],
|
||||
outputNames: readonly string[],
|
||||
feeds: FeedsType,
|
||||
arg1?: FetchesType | RunOptions,
|
||||
arg2?: RunOptions,
|
||||
): [SessionHandler.FetchesType, RunOptions] {
|
||||
const fetches: { [name: string]: OnnxValue | null } = {};
|
||||
let options: RunOptions = {};
|
||||
// check inputs
|
||||
if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
|
||||
throw new TypeError(
|
||||
"'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.",
|
||||
);
|
||||
}
|
||||
|
||||
let isFetchesEmpty = true;
|
||||
// determine which override is being used
|
||||
if (typeof arg1 === 'object') {
|
||||
if (arg1 === null) {
|
||||
throw new TypeError('Unexpected argument[1]: cannot be null.');
|
||||
}
|
||||
if (arg1 instanceof Tensor) {
|
||||
throw new TypeError("'fetches' cannot be a Tensor");
|
||||
}
|
||||
|
||||
if (Array.isArray(arg1)) {
|
||||
if (arg1.length === 0) {
|
||||
throw new TypeError("'fetches' cannot be an empty array.");
|
||||
}
|
||||
isFetchesEmpty = false;
|
||||
// output names
|
||||
for (const name of arg1) {
|
||||
if (typeof name !== 'string') {
|
||||
throw new TypeError("'fetches' must be a string array or an object.");
|
||||
}
|
||||
if (outputNames.indexOf(name) === -1) {
|
||||
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
|
||||
}
|
||||
fetches[name] = null;
|
||||
}
|
||||
|
||||
if (typeof arg2 === 'object' && arg2 !== null) {
|
||||
options = arg2;
|
||||
} else if (typeof arg2 !== 'undefined') {
|
||||
throw new TypeError("'options' must be an object.");
|
||||
}
|
||||
} else {
|
||||
// decide whether arg1 is fetches or options
|
||||
// if any output name is present and its value is valid OnnxValue, we consider it fetches
|
||||
let isFetches = false;
|
||||
const arg1Keys = Object.getOwnPropertyNames(arg1);
|
||||
for (const name of outputNames) {
|
||||
if (arg1Keys.indexOf(name) !== -1) {
|
||||
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
|
||||
if (v === null || v instanceof Tensor) {
|
||||
isFetches = true;
|
||||
isFetchesEmpty = false;
|
||||
fetches[name] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isFetches) {
|
||||
if (typeof arg2 === 'object' && arg2 !== null) {
|
||||
options = arg2;
|
||||
} else if (typeof arg2 !== 'undefined') {
|
||||
throw new TypeError("'options' must be an object.");
|
||||
}
|
||||
} else {
|
||||
options = arg1 as RunOptions;
|
||||
}
|
||||
}
|
||||
} else if (typeof arg1 !== 'undefined') {
|
||||
throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
|
||||
}
|
||||
|
||||
// check if all inputs are in feed
|
||||
for (const name of inputNames) {
|
||||
if (typeof feeds[name] === 'undefined') {
|
||||
throw new Error(`input '${name}' is missing in 'feeds'.`);
|
||||
}
|
||||
}
|
||||
|
||||
// if no fetches is specified, we use the full output names list
|
||||
if (isFetchesEmpty) {
|
||||
for (const name of outputNames) {
|
||||
fetches[name] = null;
|
||||
}
|
||||
}
|
||||
|
||||
return [fetches, options];
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler
|
||||
* and changes it into a map of Tensors.
|
||||
*
|
||||
* @param results
|
||||
* @returns
|
||||
*/
|
||||
convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
|
||||
const returnValue: { [name: string]: OnnxValue } = {};
|
||||
for (const key in results) {
|
||||
if (Object.hasOwnProperty.call(results, key)) {
|
||||
const result = results[key];
|
||||
if (result instanceof Tensor) {
|
||||
returnValue[key] = result;
|
||||
} else {
|
||||
returnValue[key] = new Tensor(result.type, result.data, result.dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
async lazyResetGrad(): Promise<void> {
|
||||
await this.handler.lazyResetGrad();
|
||||
}
|
||||
|
||||
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
|
||||
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
|
||||
async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
const [fetches, options] = this.typeNarrowingForRunStep(
|
||||
this.trainingInputNames,
|
||||
this.trainingOutputNames,
|
||||
feeds,
|
||||
arg1,
|
||||
arg2,
|
||||
);
|
||||
const results = await this.handler.runTrainStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
}
|
||||
|
||||
async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise<void> {
|
||||
if (this.hasOptimizerModel) {
|
||||
await this.handler.runOptimizerStep(options || {});
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no OptimizerModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise<ReturnType>;
|
||||
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise<ReturnType>;
|
||||
async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
if (this.hasEvalModel) {
|
||||
const [fetches, options] = this.typeNarrowingForRunStep(
|
||||
this.evalInputNames,
|
||||
this.evalOutputNames,
|
||||
feeds,
|
||||
arg1,
|
||||
arg2,
|
||||
);
|
||||
const results = await this.handler.runEvalStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no EvalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
async getParametersSize(trainableOnly = true): Promise<number> {
|
||||
return this.handler.getParametersSize(trainableOnly);
|
||||
}
|
||||
|
||||
async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise<void> {
|
||||
const paramsSize = await this.getParametersSize(trainableOnly);
|
||||
// checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
|
||||
// of parameters
|
||||
if (array.length !== 4 * paramsSize) {
|
||||
throw new Error(
|
||||
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
|
||||
'the model. Please use getParametersSize method to check.',
|
||||
);
|
||||
}
|
||||
return this.handler.loadParametersBuffer(array, trainableOnly);
|
||||
}
|
||||
|
||||
async getContiguousParameters(trainableOnly = true): Promise<OnnxValue> {
|
||||
return this.handler.getContiguousParameters(trainableOnly);
|
||||
}
|
||||
|
||||
async release(): Promise<void> {
|
||||
return this.handler.dispose();
|
||||
}
|
||||
}
|
|
@ -1,206 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-redeclare */
|
||||
|
||||
export declare namespace TrainingSession {
|
||||
/**
|
||||
* Either URI file path (string) or Uint8Array containing model or checkpoint information.
|
||||
*/
|
||||
type UriOrBuffer = string | Uint8Array;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a runtime instance of an ONNX training session,
|
||||
* which contains a model that can be trained, and, optionally,
|
||||
* an eval and optimizer model.
|
||||
*/
|
||||
export interface TrainingSession {
|
||||
// #region run()
|
||||
|
||||
/**
|
||||
* Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
|
||||
* runOptimizerStep.
|
||||
*/
|
||||
lazyResetGrad(): Promise<void>;
|
||||
|
||||
/**
|
||||
* Run TrainStep asynchronously with the given feeds and options.
|
||||
*
|
||||
* @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
|
||||
detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model training.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
|
||||
*/
|
||||
runTrainStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Run a single train step with the given inputs and options.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model training.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runTrainStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
|
||||
*
|
||||
* @param options - Optional. A set of options that controls the behavior of model optimizing.
|
||||
*/
|
||||
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region copy parameters
|
||||
|
||||
/**
|
||||
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
|
||||
* the parameters) elements of all the parameters in the training state.
|
||||
*
|
||||
* @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
|
||||
*/
|
||||
getParametersSize(trainableOnly: boolean): Promise<number>;
|
||||
|
||||
/**
|
||||
* Copies parameter values from the given buffer to the training state. Currently, only supporting models with
|
||||
* parameters of type Float32.
|
||||
*
|
||||
* @param buffer - A Uint8Array representation of Float32 parameters.
|
||||
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
|
||||
*/
|
||||
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
|
||||
|
||||
/**
|
||||
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
|
||||
* Currently, only supporting models with parameters of type Float32.
|
||||
*
|
||||
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
|
||||
* for which requires_grad is set to true. Default value is true.
|
||||
* @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
|
||||
*/
|
||||
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
|
||||
// #endregion
|
||||
|
||||
// #region release()
|
||||
|
||||
/**
|
||||
* Release the inference session and the underlying resources.
|
||||
*/
|
||||
release(): Promise<void>;
|
||||
// #endregion
|
||||
|
||||
// #region metadata
|
||||
|
||||
/**
|
||||
* Get input names of the loaded training model.
|
||||
*/
|
||||
readonly trainingInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded training model.
|
||||
*/
|
||||
readonly trainingOutputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
// #endregion
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the optional parameters that can be passed into the TrainingSessionFactory.
|
||||
*/
|
||||
export interface TrainingSessionCreateOptions {
|
||||
/**
|
||||
* URI or buffer for a .ckpt file that contains the checkpoint for the training model.
|
||||
*/
|
||||
checkpointState: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* URI or buffer for the .onnx training file.
|
||||
*/
|
||||
trainModel: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* Optional. URI or buffer for the .onnx optimizer model file.
|
||||
*/
|
||||
optimizerModel?: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* Optional. URI or buffer for the .onnx eval model file.
|
||||
*/
|
||||
evalModel?: TrainingSession.UriOrBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines method overload possibilities for creating a TrainingSession.
|
||||
*/
|
||||
export interface TrainingSessionFactory {
|
||||
// #region create()
|
||||
|
||||
/**
|
||||
* Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
|
||||
*
|
||||
* @param trainingOptions specify models and checkpoints to load into the Training Session
|
||||
* @param sessionOptions specify configuration for training session behavior
|
||||
*
|
||||
* @returns Promise that resolves to a TrainingSession object
|
||||
*/
|
||||
create(
|
||||
trainingOptions: TrainingSessionCreateOptions,
|
||||
sessionOptions?: InferenceSession.SessionOptions,
|
||||
): Promise<TrainingSession>;
|
||||
|
||||
// #endregion
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/**
|
||||
* A helper type to get certain types if they are declared in global scope.
|
||||
*
|
||||
* For example, if you installed "@webgpu/types" as a dev dependency, then `TryGetTypeIfDeclared<'GPUDevice'>` will
|
||||
* be type `GPUDevice`, otherwise it will be type `unknown`.
|
||||
*
|
||||
*
|
||||
* We don't want to introduce "@webgpu/types" as a dependency of this package because:
|
||||
*
|
||||
* (1) For JavaScript users, it's not needed. For TypeScript users, they can install it as dev dependency themselves.
|
||||
*
|
||||
* (2) because "@webgpu/types" requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its
|
||||
* version need to be chosen carefully according to the TypeScript version being used. This means so far there is not a
|
||||
* way to keep every TypeScript version happy. It turns out that we will easily broke users on some TypeScript version.
|
||||
*
|
||||
* for more info see https://github.com/gpuweb/types/issues/127
|
||||
*
|
||||
* Update (2024-08-07): The reason (2) may be no longer valid. Most people should be using TypeScript >= 5.1 by now.
|
||||
* However, we are still not sure whether introducing "@webgpu/types" as direct dependency is a good idea. We find this
|
||||
* type helper is useful for TypeScript users.
|
||||
*
|
||||
* @ignore
|
||||
*/
|
||||
export type TryGetGlobalType<Name extends string, Fallback = unknown> = typeof globalThis extends {
|
||||
[k in Name]: { prototype: infer T };
|
||||
}
|
||||
? T
|
||||
: Fallback;
|
|
@ -1,6 +1,7 @@
|
|||
{
|
||||
"entryPoints": ["lib/index.ts"],
|
||||
"excludeInternal": true,
|
||||
"intentionallyNotExported": ["TryGetGlobalType"],
|
||||
"name": "ONNX Runtime JavaScript API",
|
||||
"readme": "none",
|
||||
"cleanOutputDir": true
|
||||
|
|
|
@ -276,12 +276,12 @@
|
|||
"dev": true
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.6.1",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
|
||||
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
|
||||
"version": "1.7.9",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
|
||||
"integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.0",
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.0",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
|
@ -1581,12 +1581,12 @@
|
|||
"dev": true
|
||||
},
|
||||
"axios": {
|
||||
"version": "1.6.1",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
|
||||
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
|
||||
"version": "1.7.9",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
|
||||
"integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"follow-redirects": "^1.15.0",
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.0",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
|
|
Двоичный файл не отображается.
|
@ -2,5 +2,7 @@ distributionBase=GRADLE_USER_HOME
|
|||
distributionPath=wrapper/dists
|
||||
distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip
|
||||
networkTimeout=10000
|
||||
validateDistributionUrl=true
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
|
|
@ -55,7 +55,7 @@
|
|||
# Darwin, MinGW, and NonStop.
|
||||
#
|
||||
# (3) This script is generated from the Groovy template
|
||||
# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
|
||||
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
|
||||
# within the Gradle project.
|
||||
#
|
||||
# You can find Gradle at https://github.com/gradle/gradle/.
|
||||
|
@ -80,13 +80,11 @@ do
|
|||
esac
|
||||
done
|
||||
|
||||
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
|
||||
|
||||
APP_NAME="Gradle"
|
||||
# This is normally unused
|
||||
# shellcheck disable=SC2034
|
||||
APP_BASE_NAME=${0##*/}
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
|
||||
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD=maximum
|
||||
|
@ -133,22 +131,29 @@ location of your Java installation."
|
|||
fi
|
||||
else
|
||||
JAVACMD=java
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
if ! command -v java >/dev/null 2>&1
|
||||
then
|
||||
die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
|
||||
case $MAX_FD in #(
|
||||
max*)
|
||||
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
|
||||
# shellcheck disable=SC2039,SC3045
|
||||
MAX_FD=$( ulimit -H -n ) ||
|
||||
warn "Could not query maximum file descriptor limit"
|
||||
esac
|
||||
case $MAX_FD in #(
|
||||
'' | soft) :;; #(
|
||||
*)
|
||||
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
|
||||
# shellcheck disable=SC2039,SC3045
|
||||
ulimit -n "$MAX_FD" ||
|
||||
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||
esac
|
||||
|
@ -193,11 +198,15 @@ if "$cygwin" || "$msys" ; then
|
|||
done
|
||||
fi
|
||||
|
||||
# Collect all arguments for the java command;
|
||||
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
|
||||
# shell script including quotes and variable substitutions, so put them in
|
||||
# double quotes to make sure that they get re-expanded; and
|
||||
# * put everything else in single quotes, so that it's not re-expanded.
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
|
||||
# Collect all arguments for the java command:
|
||||
# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
|
||||
# and any embedded shellness will be escaped.
|
||||
# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
|
||||
# treated as '${Hostname}' itself on the command line.
|
||||
|
||||
set -- \
|
||||
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
||||
|
|
|
@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal
|
|||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%"=="" set DIRNAME=.
|
||||
@rem This is normally unused
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
|
@ -42,11 +43,11 @@ set JAVA_EXE=java.exe
|
|||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if %ERRORLEVEL% equ 0 goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
echo. 1>&2
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2
|
||||
echo. 1>&2
|
||||
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
|
||||
echo location of your Java installation. 1>&2
|
||||
|
||||
goto fail
|
||||
|
||||
|
@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
|||
|
||||
if exist "%JAVA_EXE%" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
echo. 1>&2
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2
|
||||
echo. 1>&2
|
||||
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
|
||||
echo location of your Java installation. 1>&2
|
||||
|
||||
goto fail
|
||||
|
||||
|
|
|
@ -487,7 +487,7 @@ export const prepareInputOutputTensor = (
|
|||
}
|
||||
|
||||
if (location === 'gpu-buffer') {
|
||||
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
|
||||
const gpuBuffer = tensor[2].gpuBuffer;
|
||||
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
|
||||
|
||||
const registerBuffer = wasm.jsepRegisterBuffer;
|
||||
|
|
|
@ -861,9 +861,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/cross-spawn": {
|
||||
"version": "6.0.5",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
|
||||
"integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
|
||||
"version": "6.0.6",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
|
||||
"integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"nice-try": "^1.0.4",
|
||||
|
@ -4312,9 +4312,9 @@
|
|||
}
|
||||
},
|
||||
"cross-spawn": {
|
||||
"version": "6.0.5",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
|
||||
"integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
|
||||
"version": "6.0.6",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
|
||||
"integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"nice-try": "^1.0.4",
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "contrib_ops/cpu/bert/rotary_embedding.h"
|
||||
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
|
||||
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
using onnxruntime::concurrency::ThreadPool;
|
||||
|
@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
|
|||
const T* cos_data = cos_cache + cache_offset;
|
||||
const T* sin_data = sin_cache + cache_offset;
|
||||
|
||||
int cache_idx = 0;
|
||||
bool sign = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = i & 1;
|
||||
j = sign ? i - 1 : i + 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i >= half_rotary_emb_dim);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
|
||||
float input_data_j = static_cast<float>(input_data[j]);
|
||||
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
|
||||
if (sign) {
|
||||
output_data_i += input_data_j * sin_data_cache_idx;
|
||||
} else {
|
||||
output_data_i -= input_data_j * sin_data_cache_idx;
|
||||
}
|
||||
output_data[i] = static_cast<T>(output_data_i);
|
||||
}
|
||||
for (int i = rotary_emb_dim; i < head_size; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
|
||||
|
||||
if (rotary_emb_dim < head_size) {
|
||||
std::memcpy(output_data + rotary_emb_dim,
|
||||
input_data + rotary_emb_dim,
|
||||
(head_size - rotary_emb_dim) * sizeof(T));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
|
@ -31,6 +31,7 @@ Subgraph::Subgraph(
|
|||
allocator_(nullptr),
|
||||
is_output_float16_(false) {
|
||||
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
|
||||
used_implicit_inputs = std::vector<bool>(num_implicit_inputs, true);
|
||||
|
||||
auto& subgraph_inputs = subgraph.GetInputs();
|
||||
auto& subgraph_outputs = subgraph.GetOutputs();
|
||||
|
@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state,
|
|||
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
|
||||
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
|
||||
|
||||
for (auto& entry : node.ImplicitInputDefs()) {
|
||||
feed_names.push_back(entry->Name());
|
||||
const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
|
||||
|
||||
const auto& implicit_input_defs = node.ImplicitInputDefs();
|
||||
for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) {
|
||||
const auto* entry = implicit_input_defs[i];
|
||||
int idx;
|
||||
if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) {
|
||||
feed_names.push_back(entry->Name());
|
||||
} else {
|
||||
--num_implicit_inputs;
|
||||
used_implicit_inputs[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
InlinedVector<OrtDevice> feed_locations;
|
||||
|
|
|
@ -31,6 +31,7 @@ class Subgraph {
|
|||
const GraphViewer& subgraph; // The subgraph
|
||||
|
||||
int num_implicit_inputs;
|
||||
std::vector<bool> used_implicit_inputs;
|
||||
|
||||
int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience.
|
||||
int num_subgraph_outputs; // Same as subgraph_output_names.size()
|
||||
|
|
|
@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
|
|||
}
|
||||
|
||||
// Pass through implicit inputs.
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
decoder_feeds.push_back(*entry);
|
||||
for (size_t i = 0; i < implicit_inputs.size(); ++i) {
|
||||
const auto* entry = implicit_inputs[i];
|
||||
if (used_implicit_inputs[i]) {
|
||||
decoder_feeds.push_back(*entry);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
|
|||
pinned_allocator,
|
||||
location));
|
||||
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
feeds.push_back(*entry);
|
||||
for (size_t i = 0; i < implicit_inputs.size(); ++i) {
|
||||
const auto* entry = implicit_inputs[i];
|
||||
if (used_implicit_inputs[i]) {
|
||||
feeds.push_back(*entry);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -138,7 +138,8 @@ class PlannerImpl {
|
|||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
|
||||
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan,
|
||||
const logging::Logger& logger)
|
||||
: context_(&context),
|
||||
plan_(plan),
|
||||
parent_node_(parent_node),
|
||||
|
@ -148,14 +149,15 @@ class PlannerImpl {
|
|||
kernel_create_info_map_(kernel_create_info_map),
|
||||
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
|
||||
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map) {}
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map),
|
||||
logger_(logger) {
|
||||
}
|
||||
|
||||
Status CreatePlan(
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
const IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
#endif
|
||||
const PathString& partition_config_file,
|
||||
const logging::Logger& logger);
|
||||
const PathString& partition_config_file);
|
||||
|
||||
private:
|
||||
gsl::not_null<const ISequentialPlannerContext*> context_;
|
||||
|
@ -183,6 +185,12 @@ class PlannerImpl {
|
|||
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
|
||||
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;
|
||||
|
||||
// logger_ is not currently used in a minimal build
|
||||
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
[[maybe_unused]]
|
||||
#endif
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
|
||||
struct OrtValueInfo {
|
||||
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
|
||||
|
@ -213,6 +221,7 @@ class PlannerImpl {
|
|||
FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point)
|
||||
: ml_value(ort_value), deallocate_point(dealloc_point) {}
|
||||
};
|
||||
|
||||
// freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when
|
||||
// they became free (more recently freed earlier in the list).
|
||||
std::list<FreeBufferInfo> freelist_;
|
||||
|
@ -225,7 +234,8 @@ class PlannerImpl {
|
|||
}
|
||||
|
||||
int& UseCount(OrtValueIndex n) {
|
||||
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
|
||||
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
|
||||
"invalid value index: ", n, " against size ", ort_value_info_.size());
|
||||
return ort_value_info_[n].usecount;
|
||||
}
|
||||
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
|
||||
|
@ -335,9 +345,9 @@ class PlannerImpl {
|
|||
// we cannot.
|
||||
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
*reusable_input = Index(p_input_arg->Name());
|
||||
|
@ -361,9 +371,9 @@ class PlannerImpl {
|
|||
// we cannot.
|
||||
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
|
||||
<< producer_node->Name() << " which has external outputs. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
*reusable_input = Index(p_input_arg->Name());
|
||||
|
@ -397,8 +407,8 @@ class PlannerImpl {
|
|||
}
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -448,8 +458,8 @@ class PlannerImpl {
|
|||
return true;
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs.";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs.";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1198,9 +1208,9 @@ class PlannerImpl {
|
|||
// Otherwise, we cannot reuse the buffer.
|
||||
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1241,9 +1251,9 @@ class PlannerImpl {
|
|||
// Otherwise, we cannot reuse the buffer.
|
||||
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
|
||||
if (producer_node && HasExternalOutputs(*producer_node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
|
||||
<< producer_node->Name() << " which has external outputs is reused. "
|
||||
<< "Be cautious the reuse MUST be a read-only usage.";
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1290,8 +1300,8 @@ class PlannerImpl {
|
|||
}
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
|
||||
<< producer_node->Name() << " as it has external outputs";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1869,8 +1879,7 @@ class PlannerImpl {
|
|||
}
|
||||
|
||||
#ifndef ORT_ENABLE_STREAM
|
||||
void PartitionIntoStreams(const logging::Logger& /*logger*/,
|
||||
const ExecutionProviders& /*execution_providers*/,
|
||||
void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/,
|
||||
const PathString& /*partition_config_file*/) {
|
||||
if (graph_viewer_.NumberOfNodes() > 0) {
|
||||
stream_nodes_.push_back({});
|
||||
|
@ -1915,11 +1924,11 @@ class PlannerImpl {
|
|||
|
||||
#else
|
||||
|
||||
void
|
||||
PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
|
||||
const PathString& partition_config_file) {
|
||||
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
|
||||
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
|
||||
void PartitionIntoStreams(const ExecutionProviders& execution_providers,
|
||||
const PathString& partition_config_file) {
|
||||
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file);
|
||||
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_,
|
||||
context_->GetExecutionOrder());
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
plan_.node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
|
||||
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
|
||||
|
@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan(
|
|||
#ifdef ORT_ENABLE_STREAM
|
||||
const IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
#endif
|
||||
const PathString& partition_config_file,
|
||||
const logging::Logger& logger) {
|
||||
const PathString& partition_config_file) {
|
||||
// 1. partition graph into streams
|
||||
PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file);
|
||||
PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file);
|
||||
|
||||
// 2. initialize the plan based on stream partition result
|
||||
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
|
||||
|
@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan(
|
|||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
|
||||
kernel_create_info_map, subgraphs_kernel_create_info_maps,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map, context, *plan);
|
||||
ort_value_name_idx_map, context, *plan, logger);
|
||||
|
||||
return planner.CreatePlan(
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
stream_handle_registry,
|
||||
#endif
|
||||
partition_config_file,
|
||||
logger);
|
||||
partition_config_file);
|
||||
}
|
||||
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
|
|
|
@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) {
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) {
|
||||
// automatic conversion from const std::vector&
|
||||
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
|
||||
InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());
|
||||
|
@ -83,7 +84,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
|
|||
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
|
||||
for (auto& consumer_node : consumer_nodes) {
|
||||
candidates.push(consumer_node->Index());
|
||||
LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
|
||||
LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -159,9 +160,9 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
|
|||
|
||||
if (place_in_cpu) {
|
||||
cpu_nodes.insert(cur);
|
||||
LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
|
||||
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
|
||||
<< " capable of executing this node";
|
||||
LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
|
||||
<< " because the CPU execution path is deemed faster than overhead involved with execution "
|
||||
"on other EPs capable of executing this node";
|
||||
for (auto* output : node->OutputDefs()) {
|
||||
cpu_output_args.insert(output);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
|
||||
/**
|
||||
Returns a list of nodes that are preferred on CPU.
|
||||
|
@ -19,6 +22,7 @@ namespace onnxruntime {
|
|||
*/
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep,
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
||||
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) {
|
||||
auto& current_ep = params.current_ep.get();
|
||||
const auto& ep_type = current_ep.Type();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) {
|
||||
LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
|
||||
LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
|
||||
"the layout transformer.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
|||
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
|
||||
const KernelLookup kernel_lookup{ep_type,
|
||||
kernel_registries_for_ep,
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver()};
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver(),
|
||||
logger};
|
||||
|
||||
auto& graph = params.graph.get();
|
||||
auto& capabilities = params.capabilities.get();
|
||||
|
@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
|
|||
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
|
||||
const KernelRegistryManager& kernel_registry_mgr,
|
||||
const IExecutionProvider& current_ep,
|
||||
const logging::Logger& logger,
|
||||
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
|
||||
const auto& ep_type = current_ep.Type();
|
||||
|
||||
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
|
||||
const KernelLookup kernel_lookup{ep_type,
|
||||
kernel_registries_for_ep,
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver()};
|
||||
kernel_registry_mgr.GetKernelTypeStrResolver(),
|
||||
logger};
|
||||
|
||||
// TODO: Provide EP with a capability to look inside the functions.
|
||||
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
|
||||
|
@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id,
|
||||
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
|
||||
const layout_transformation::DebugGraphFn& debug_graph_fn) {
|
||||
const layout_transformation::DebugGraphFn& debug_graph_fn,
|
||||
const logging::Logger& logger) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
if (graph.NumberOfNodes() == 0) {
|
||||
|
@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
// we pass through the FuncManager from the top level graph
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
|
||||
transform_layout_fn, debug_graph_fn));
|
||||
transform_layout_fn, debug_graph_fn, logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
std::cref(transform_layout_fn),
|
||||
std::cref(debug_graph_fn)};
|
||||
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
|
||||
if (n != nullptr) {
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
|
||||
nodes_to_compile.push_back(n);
|
||||
capabilities_to_compile.push_back(std::move(capability));
|
||||
} else {
|
||||
|
@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
|
|||
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
|
||||
const KernelRegistryManager& kernel_registry_mgr,
|
||||
Graph& graph,
|
||||
const logging::Logger& logger,
|
||||
InlinedHashSet<std::string>& not_inlined,
|
||||
size_t& inlined_count) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
|
@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
|||
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
|
||||
kernel_registry_mgr,
|
||||
*subgraph,
|
||||
logger,
|
||||
not_inlined,
|
||||
inlined_count));
|
||||
}
|
||||
|
@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
|||
InlinedHashSet<NodeIndex> claimed_by_ep;
|
||||
for (const auto& ep : execution_providers) {
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger,
|
||||
capabilities));
|
||||
for (auto& capability : capabilities) {
|
||||
const auto& nodes = capability->sub_graph->nodes;
|
||||
if (nodes.size() == 1) {
|
||||
|
@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
|
|||
|
||||
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
|
||||
const ExecutionProviders& execution_providers,
|
||||
KernelRegistryManager& kernel_registry_manager) {
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const logging::Logger& logger) {
|
||||
bool modified_graph = false;
|
||||
|
||||
auto& graph = partition_params.graph.get();
|
||||
|
@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
|
|||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager,
|
||||
fused_kernel_registry, *ep, mode, fused_node_unique_id,
|
||||
transform_layout_function,
|
||||
partition_params.debug_graph_fn));
|
||||
partition_params.debug_graph_fn,
|
||||
logger));
|
||||
}
|
||||
|
||||
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
|
||||
|
@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
|
|||
|
||||
static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params,
|
||||
KernelRegistryManager& kernel_registry_mgr,
|
||||
IExecutionProvider& current_ep) {
|
||||
IExecutionProvider& current_ep,
|
||||
const logging::Logger& logger) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
auto& graph = partition_params.graph.get();
|
||||
|
@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
auto& subgraph = *entry.second;
|
||||
PartitionParams subgraph_partition_params = partition_params;
|
||||
subgraph_partition_params.graph = std::ref(subgraph);
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep,
|
||||
logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
};
|
||||
// clang-format on
|
||||
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
|
|||
// Simplified partitioning where custom EPs may produce compiled nodes.
|
||||
static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
|
||||
const ExecutionProviders& execution_providers,
|
||||
KernelRegistryManager& kernel_registry_manager) {
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const logging::Logger& logger) {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : execution_providers) {
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
|
|||
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
|
||||
kernel_registry_manager,
|
||||
graph,
|
||||
logger,
|
||||
not_inlined,
|
||||
inlined_count));
|
||||
|
||||
|
@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
|||
|
||||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
|
||||
providers_, kernel_registry_mgr_));
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
|
||||
|
||||
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
|
||||
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
|
@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
|
||||
#endif //! defined(ORT_MINIMAL_BUILD)
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params,
|
||||
providers_, kernel_registry_mgr_));
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger));
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
|
|||
public:
|
||||
KernelLookup(ProviderType provider_type,
|
||||
gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver)
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger)
|
||||
: provider_type_{provider_type},
|
||||
kernel_registries_{kernel_registries},
|
||||
kernel_type_str_resolver_{kernel_type_str_resolver} {
|
||||
kernel_type_str_resolver_{kernel_type_str_resolver},
|
||||
logger_{logger} {
|
||||
ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified.");
|
||||
}
|
||||
|
||||
const KernelCreateInfo* LookUpKernel(const Node& node) const override {
|
||||
const KernelCreateInfo* kernel_create_info{};
|
||||
for (const auto& registry : kernel_registries_) {
|
||||
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_,
|
||||
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_,
|
||||
&kernel_create_info);
|
||||
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
|
||||
return kernel_create_info;
|
||||
|
@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
|
|||
ProviderType provider_type_;
|
||||
const gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries_;
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver_;
|
||||
const logging::Logger& logger_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
const auto& node_provider = node.GetExecutionProviderType();
|
||||
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
|
||||
|
@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
std::ostream_iterator<std::string>(oss, "\n"));
|
||||
oss << ")";
|
||||
|
||||
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
|
||||
}
|
||||
|
||||
|
@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
|
||||
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out);
|
||||
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out);
|
||||
}
|
||||
|
||||
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
|
||||
const TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out);
|
||||
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out);
|
||||
}
|
||||
|
||||
static bool KernelDefCompatible(int version, const KernelDef& kernel_def,
|
||||
|
@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
|
|||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const logging::Logger& logger,
|
||||
const KernelCreateInfo** out) const {
|
||||
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider));
|
||||
if (out) *out = nullptr;
|
||||
|
@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
|
|||
std::ostream_iterator<std::string>(oss, "\n"));
|
||||
oss << ")";
|
||||
|
||||
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistr
|
|||
}
|
||||
#endif
|
||||
|
||||
Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
||||
Status KernelRegistryManager::SearchKernelRegistry(const Node& node, const logging::Logger& logger,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const {
|
||||
Status status;
|
||||
|
||||
|
@ -82,7 +82,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
}
|
||||
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
|
||||
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
}
|
||||
|
||||
if (p != nullptr) {
|
||||
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
|
||||
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
|
|||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
|
||||
}
|
||||
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r,
|
||||
const Node& node,
|
||||
const std::string& provider_type,
|
||||
const logging::Logger& logger) {
|
||||
const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
|
||||
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
|
||||
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver());
|
||||
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(),
|
||||
logger);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -67,13 +67,14 @@ class KernelRegistryManager {
|
|||
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status SearchKernelRegistry(const Node& node,
|
||||
Status SearchKernelRegistry(const Node& node, const logging::Logger& logger,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
||||
/**
|
||||
* Whether this node can be run on this provider
|
||||
*/
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status CreateKernel(const Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
|
|
|
@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
|
|||
bool saving_ort_format) {
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
|
||||
if (!status.IsOK() && saving_ort_format) {
|
||||
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
|
||||
// in that case we assigned the node to that EP but do not compile it into a fused node.
|
||||
|
@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
|
|||
// at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
|
||||
// if that's not possible for some reason we can fallback to the CPU EP implementation.
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
|
|
@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
|
|||
AttributeProto::STRING,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
|
||||
.Attr(
|
||||
"max_size",
|
||||
"max size in the context. Usage depend on the EP.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(
|
||||
0,
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "core/graph/constants.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/shape_inference_functions.h"
|
||||
#include "onnx/onnx-ml.pb.h" // ?
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
|
||||
// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
|
||||
|
@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx);
|
|||
void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape,
|
||||
int input1Idx, int input2Idx);
|
||||
namespace defs::math::utils {
|
||||
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
|
||||
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
|
||||
}
|
||||
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
}
|
||||
}
|
||||
|
||||
if (all_lengths_known) {
|
||||
output_shape->mutable_dim(axis)->set_dim_value(total_length);
|
||||
}
|
||||
}));
|
||||
if (all_lengths_known) {
|
||||
output_shape->mutable_dim(axis)->set_dim_value(total_length);
|
||||
}
|
||||
}));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema()
|
||||
.SetDoc("Return elements, either from X or Y, depending on condition.")
|
||||
|
@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.",
|
||||
AttributeProto::INT, OPTIONAL_VALUE)
|
||||
.Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
|
||||
.Attr("past_present_share_buffer",
|
||||
"Corresponding past and present are same tensor, its shape is "
|
||||
"(2, batch_size, num_heads, max_sequence_length, head_size)",
|
||||
AttributeProto::INT, OPTIONAL_VALUE)
|
||||
.Attr("mask_filter_value",
|
||||
|
|
|
@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
|
|||
size_t Count
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief rotary embedding for one hidden state vector
|
||||
*
|
||||
* @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
|
||||
* @param input: input tensor, of shape [dim]
|
||||
* @param sin: sin tensor, of shape [dim/2]
|
||||
* @param cos: cos tensor, of shape [dim/2]
|
||||
* @param dim: dimension of rotary embedding
|
||||
* @param interleaved: whether the real part and imaginary parts are interleaved
|
||||
* @param output: output tensor, of shape [dim]
|
||||
*/
|
||||
template <typename T>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow(
|
||||
const T* input,
|
||||
const T* sin,
|
||||
const T* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
T* output
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Whether current CPU supports FP16 acceleration.
|
||||
*/
|
||||
|
|
|
@ -6,7 +6,7 @@ Licensed under the MIT License.
|
|||
|
||||
Module Name:
|
||||
|
||||
fp16_neon_common.cpp
|
||||
cast_kernel_neon.cpp
|
||||
|
||||
Abstract:
|
||||
|
|
@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
|
|||
|
||||
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
|
||||
|
||||
//
|
||||
// Rotary embedding dispatch structure.
|
||||
//
|
||||
struct MLAS_ROPE_DISPATCH;
|
||||
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
|
||||
|
||||
|
||||
//
|
||||
// Quantized depthwise convolution kernels.
|
||||
//
|
||||
|
@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
|
|||
|
||||
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
|
||||
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
|
||||
|
||||
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
|
||||
};
|
||||
|
||||
inline
|
||||
|
|
|
@ -543,6 +543,7 @@ Return Value:
|
|||
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
|
||||
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
|
||||
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
|
||||
this->RopeDispatch = &MlasRopeDispatchNeon;
|
||||
|
||||
//
|
||||
// Check if the processor supports ASIMD dot product instructions.
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Intel Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements rotary embedding kernels for fp32/16.
|
||||
|
||||
--*/
|
||||
|
||||
#include "rotary_embedding.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow_FallBack(
|
||||
const T* input_data,
|
||||
const T* sin_data,
|
||||
const T* cos_data,
|
||||
size_t rotary_emb_dim,
|
||||
bool interleaved,
|
||||
T* output_data
|
||||
) {
|
||||
const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
|
||||
size_t cache_idx = 0;
|
||||
bool sign = false;
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = i & 1;
|
||||
j = sign ? i - 1 : i + 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i >= half_rotary_emb_dim);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
|
||||
float input_data_j = static_cast<float>(input_data[j]);
|
||||
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
|
||||
if (sign) {
|
||||
output_data_i += input_data_j * sin_data_cache_idx;
|
||||
} else {
|
||||
output_data_i -= input_data_j * sin_data_cache_idx;
|
||||
}
|
||||
output_data[i] = static_cast<T>(output_data_i);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
template <>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow<float>(
|
||||
const float* input,
|
||||
const float* sin,
|
||||
const float* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
float* output
|
||||
) {
|
||||
const auto* dispatch = GetMlasPlatform().RopeDispatch;
|
||||
|
||||
if (dispatch == nullptr || dispatch->SRope == nullptr) {
|
||||
MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch->SRope(input, sin, cos, dim, interleaved, output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow<MLAS_FP16>(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
) {
|
||||
const auto* dispatch = GetMlasPlatform().RopeDispatch;
|
||||
|
||||
if (dispatch == nullptr || dispatch->HRope == nullptr) {
|
||||
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch->HRope(input, sin, cos, dim, interleaved, output);
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes kernel function prototypes and helper functions for
|
||||
implementing rotary embedding.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
struct MLAS_ROPE_DISPATCH {
|
||||
// rotary embedding kernel for fp32
|
||||
typedef void(SRope_Fn)(
|
||||
const float* input,
|
||||
const float* sin,
|
||||
const float* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
float* output
|
||||
);
|
||||
|
||||
SRope_Fn* SRope = nullptr;
|
||||
|
||||
// rotary embedding kernel for fp16
|
||||
typedef void(HRope_Fn)(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
);
|
||||
|
||||
HRope_Fn* HRope = nullptr;
|
||||
};
|
|
@ -0,0 +1,32 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the rotary embedding kernels for ARM NEON.
|
||||
|
||||
--*/
|
||||
|
||||
#include "rotary_embedding.h"
|
||||
#include "rotary_embedding_kernel_neon.h"
|
||||
|
||||
//
|
||||
// Kernel dispatch structure definition.
|
||||
//
|
||||
const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
|
||||
MLAS_ROPE_DISPATCH d;
|
||||
|
||||
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
|
||||
if (MlasFp16AccelerationSupported()) {
|
||||
d.HRope = rope_neon::RopeKernel_Fp16;
|
||||
}
|
||||
#endif
|
||||
return d;
|
||||
}();
|
|
@ -0,0 +1,37 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes function declarations and common helper functions for
|
||||
rotary embedding on ARM cpu.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
namespace rope_neon {
|
||||
|
||||
// Rotary embedding kernel for fp16. Embed one hidden state vector.
|
||||
void
|
||||
RopeKernel_Fp16(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
);
|
||||
|
||||
} // namespace rope_neon
|
|
@ -0,0 +1,253 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon_fp16.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the fp16 rotary embedding kernels for ARM NEON.
|
||||
|
||||
--*/
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <cassert>
|
||||
|
||||
#include "fp16_common.h"
|
||||
#include "rotary_embedding.h"
|
||||
#include "rotary_embedding_kernel_neon.h"
|
||||
|
||||
namespace rope_neon {
|
||||
|
||||
namespace {
|
||||
|
||||
template <bool interleaved>
|
||||
void
|
||||
RopeKernel_Fp16_Impl(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
);
|
||||
|
||||
template <>
|
||||
void
|
||||
RopeKernel_Fp16_Impl<false>(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
) {
|
||||
const size_t half_dim = dim >> 1;
|
||||
size_t i = 0, j = half_dim;
|
||||
for (; i + 7 < half_dim; i += 8, j += 8) {
|
||||
float16x8_t real = MlasLoadFloat16x8(input + i);
|
||||
float16x8_t imag = MlasLoadFloat16x8(input + j);
|
||||
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
|
||||
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
|
||||
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
|
||||
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreFloat16x8(output + i, real_out);
|
||||
MlasStoreFloat16x8(output + j, imag_out);
|
||||
}
|
||||
for (; i + 3 < half_dim; i += 4, j += 4) {
|
||||
float16x4_t real = MlasLoadFloat16x4(input + i);
|
||||
float16x4_t imag = MlasLoadFloat16x4(input + j);
|
||||
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
|
||||
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreFloat16x4(output + i, real_out);
|
||||
MlasStoreFloat16x4(output + j, imag_out);
|
||||
}
|
||||
if (half_dim - i == 3) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
|
||||
real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
|
||||
imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
|
||||
} else if (half_dim - i == 2) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
|
||||
} else if (half_dim - i == 1) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void
|
||||
RopeKernel_Fp16_Impl<true>(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
) {
|
||||
size_t i = 0;
|
||||
for (; i + 15 < dim; i += 16) {
|
||||
float16x8_t x0 = MlasLoadFloat16x8(input + i);
|
||||
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
|
||||
float16x8_t real = vuzp1q_f16(x0, x1);
|
||||
float16x8_t imag = vuzp2q_f16(x0, x1);
|
||||
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
|
||||
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
|
||||
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
|
||||
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
|
||||
float16x8_t y0 = vzip1q_f16(real_out, imag_out);
|
||||
float16x8_t y1 = vzip2q_f16(real_out, imag_out);
|
||||
MlasStoreFloat16x8(output + i, y0);
|
||||
MlasStoreFloat16x8(output + i + 8, y1);
|
||||
}
|
||||
for (; i + 7 < dim; i += 8) {
|
||||
float16x4_t x0 = MlasLoadFloat16x4(input + i);
|
||||
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
|
||||
float16x4_t real = vuzp1_f16(x0, x1);
|
||||
float16x4_t imag = vuzp2_f16(x0, x1);
|
||||
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
|
||||
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
float16x4_t y0 = vzip1_f16(real_out, imag_out);
|
||||
float16x4_t y1 = vzip2_f16(real_out, imag_out);
|
||||
MlasStoreFloat16x4(output + i, y0);
|
||||
MlasStoreFloat16x4(output + i + 4, y1);
|
||||
}
|
||||
if (dim - i == 6) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
|
||||
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
|
||||
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
|
||||
} else if (dim - i == 4) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
|
||||
} else if (dim - i == 2) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void
|
||||
RopeKernel_Fp16(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
) {
|
||||
// real part and imaginary part must be paired
|
||||
assert(dim % 2 == 0);
|
||||
|
||||
const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
|
||||
const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
|
||||
const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
|
||||
auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
|
||||
|
||||
if (interleaved) {
|
||||
RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
|
||||
} else {
|
||||
RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rope_neon
|
|
@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
// Create execution frame for executing constant nodes.
|
||||
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
|
||||
is_sparse_initializer_check);
|
||||
is_sparse_initializer_check, logger);
|
||||
#else
|
||||
// Create execution frame for executing constant nodes.
|
||||
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
|
||||
[](std::string const&) { return false; });
|
||||
OptimizerExecutionFrame::Info info(
|
||||
{node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; },
|
||||
logger);
|
||||
#endif
|
||||
|
||||
std::vector<int> fetch_mlvalue_idxs;
|
||||
|
|
|
@ -190,6 +190,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
TransformerLevel level,
|
||||
const SessionOptions& session_options,
|
||||
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
|
||||
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
|
||||
|
@ -404,7 +405,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
}
|
||||
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
|
||||
logger);
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
|
@ -437,6 +439,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
const SessionOptions& session_options,
|
||||
const SatApplyContextVariant& apply_context,
|
||||
const IExecutionProvider& cpu_execution_provider,
|
||||
const logging::Logger& logger,
|
||||
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
|
||||
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
|
||||
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
|
||||
|
@ -490,7 +493,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
#ifndef DISABLE_CONTRIB_OPS
|
||||
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
|
||||
logger);
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
|
|
|
@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
|
|||
// going to a node that will need a Cast.
|
||||
//
|
||||
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph,
|
||||
const KernelRegistry& cpu_kernel_registry,
|
||||
const logging::Logger& logger) {
|
||||
// we can check if it's an isolated fp16 node
|
||||
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
|
||||
// does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node),
|
||||
|
@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto lookup_status = cpu_kernel_registry.TryFindKernel(
|
||||
kCpuExecutionProvider, node.OpType(), node.Domain(),
|
||||
node.SinceVersion(), type_constraint_map, &kernel_create_info);
|
||||
node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
|
||||
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
return false;
|
||||
}
|
||||
|
||||
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry,
|
||||
const logging::Logger& logger) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) {
|
||||
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
|
||||
node.SetExecutionProviderType("");
|
||||
}
|
||||
|
@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
return dst_bit_length <= src_bit_length;
|
||||
}
|
||||
|
||||
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
|
||||
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") ||
|
||||
(*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
if (force_cpu_fp32_)
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_));
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger));
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -44,7 +44,9 @@ NhwcConvLookup(
|
|||
return &(iter->second);
|
||||
}
|
||||
|
||||
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
|
||||
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
|
||||
std::shared_ptr<KernelRegistry> cpu_kernel_registry,
|
||||
const logging::Logger& logger) noexcept
|
||||
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
|
||||
if (!cpu_kernel_registry) {
|
||||
// This is a CPU op nodes optimizer, not useful if cpu EP is not available.
|
||||
|
@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_,
|
||||
qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info);
|
||||
qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_,
|
||||
qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info);
|
||||
qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_,
|
||||
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_,
|
||||
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_,
|
||||
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
|
|||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto status = cpu_kernel_registry->TryFindKernel(
|
||||
kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_,
|
||||
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info);
|
||||
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info);
|
||||
if (status.IsOK() && kernel_create_info != nullptr) {
|
||||
kernel_create_info = nullptr;
|
||||
conv_table_.emplace(
|
||||
|
|
|
@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed.
|
|||
class NhwcTransformer : public GraphTransformer {
|
||||
private:
|
||||
public:
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry,
|
||||
const logging::Logger& logger) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Usually called right after constructor, it shows whether
|
||||
|
|
|
@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
const InitializedTensorSet& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger)
|
||||
: execution_provider_(execution_provider),
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func) {
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func),
|
||||
logger_(logger) {
|
||||
allocator_ptr_ = std::make_shared<CPUAllocator>();
|
||||
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
|
||||
|
||||
|
@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
|
||||
const std::filesystem::path& /* model_path */,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger)
|
||||
: execution_provider_(execution_provider),
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func) {
|
||||
is_sparse_initializer_func_(is_sparse_initializer_func),
|
||||
logger_(logger) {
|
||||
allocator_ptr_ = std::make_shared<CPUAllocator>();
|
||||
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
|
||||
|
||||
|
@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
|
|||
Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const {
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
|
||||
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
|
||||
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out);
|
||||
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out);
|
||||
}
|
||||
|
||||
static Status TryCreateKernel(const Node& node,
|
||||
|
@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node,
|
|||
FuncManager& funcs_mgr,
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const ConfigOptions& config_options,
|
||||
const logging::Logger& logger,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
|
||||
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger,
|
||||
&kernel_create_info));
|
||||
|
||||
static const AllocatorMap dummy_allocators;
|
||||
|
@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption
|
|||
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
|
||||
FuncManager func;
|
||||
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_,
|
||||
op_kernel);
|
||||
|
||||
// Kernel found in the CPU kernel registry
|
||||
|
|
|
@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
const InitializedTensorSet& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Info(const std::vector<const Node*>& nodes,
|
||||
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
|
||||
const std::filesystem::path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
~Info() = default;
|
||||
|
||||
|
@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
std::unique_ptr<NodeIndexInfo> node_index_info_;
|
||||
const IExecutionProvider& execution_provider_;
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info);
|
||||
};
|
||||
|
|
|
@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion(
|
|||
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
|
||||
}
|
||||
|
||||
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
||||
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) {
|
||||
constexpr size_t w_idx = 1;
|
||||
constexpr size_t w_zp_idx = 9;
|
||||
constexpr size_t r_idx = 2;
|
||||
|
@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) ||
|
||||
!graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) ||
|
||||
r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
<< " cannot locate recurrence tensor of const int8 type,"
|
||||
<< " int8 overflow might impact precision !";
|
||||
return false;
|
||||
|
@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) ||
|
||||
!graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) ||
|
||||
r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
|
||||
<< " unable to locate recurrence tensor or its zero point value,"
|
||||
<< " int8 overflow might impact precision !";
|
||||
return false;
|
||||
|
@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int
|
|||
if (graph_utils::IsSupportedOptypeVersionAndDomain(
|
||||
op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) {
|
||||
// This one has two set of quantized arguments
|
||||
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph);
|
||||
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger);
|
||||
continue; // go on to next operator node
|
||||
}
|
||||
|
||||
|
|
|
@ -291,7 +291,8 @@ SelectorManager::SelectorManager() {
|
|||
InitializeSelectorsMap();
|
||||
}
|
||||
|
||||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
|
||||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer,
|
||||
const logging::Logger& logger) const {
|
||||
std::vector<NodeGroup> qdq_selections;
|
||||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
|
@ -313,7 +314,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
|
||||
if (!versions.empty()) {
|
||||
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType();
|
||||
LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -329,7 +330,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
}
|
||||
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) {
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
|
@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
|||
|
||||
// Get QDQ NodeUnits first
|
||||
QDQ::SelectorManager selector_mgr;
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger);
|
||||
|
||||
for (const auto& qdq_selection : qdq_selections) {
|
||||
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
class GraphViewer;
|
||||
class Node;
|
||||
|
||||
|
@ -65,7 +67,7 @@ class SelectorManager {
|
|||
|
||||
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
|
||||
// Can be used in QDQ support in different EPs
|
||||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;
|
||||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const;
|
||||
|
||||
private:
|
||||
Selectors qdq_selectors_;
|
||||
|
@ -88,7 +90,7 @@ class SelectorManager {
|
|||
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
|
||||
// library whereas it should be able to be used by an EP with no dependency on optimizers.
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer);
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -17,13 +17,22 @@ class TransformerMemcpyImpl {
|
|||
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
|
||||
: graph_(graph), provider_(provider) {}
|
||||
|
||||
bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);
|
||||
bool ModifyGraph(const KernelRegistryManager& schema_registries,
|
||||
const logging::Logger& logger,
|
||||
int& copy_node_counter);
|
||||
|
||||
private:
|
||||
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
|
||||
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
|
||||
void ProcessDefs(onnxruntime::Node& node,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger);
|
||||
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
const logging::Logger& logger);
|
||||
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
|
||||
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);
|
||||
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
|
||||
const InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger);
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl);
|
||||
|
@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
// find defs that require copy
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
// as we process the defs, collect all the initializers consumed at the current graph level
|
||||
ProcessDefs(node, kernel_registries, initializers_consumed);
|
||||
ProcessDefs(node, kernel_registries, initializers_consumed, logger);
|
||||
}
|
||||
|
||||
// for initializers shared by different providers, create dups
|
||||
if (ProcessInitializers(kernel_registries, initializers_consumed))
|
||||
if (ProcessInitializers(kernel_registries, initializers_consumed, logger))
|
||||
modified = true;
|
||||
|
||||
for (auto arg : graph_.GetInputs())
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : non_provider_input_defs_)
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : non_provider_output_defs_)
|
||||
BuildDefsMapping(arg, kernel_registries);
|
||||
BuildDefsMapping(arg, kernel_registries, logger);
|
||||
|
||||
for (auto arg : graph_.GetInputs())
|
||||
// For inputs we need to create a copy node only when the input is connected to both provider
|
||||
|
@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
return modified;
|
||||
}
|
||||
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed) {
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger) {
|
||||
auto node_provider_type = node.GetExecutionProviderType();
|
||||
if ((node_provider_type == provider_) ||
|
||||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
|
||||
|
@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
provider_nodes_.insert(&node);
|
||||
// note KernelCreateInfo might be nullptr for custom kernel
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci));
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci));
|
||||
|
||||
bool is_implicit_input = false;
|
||||
auto process_inputs =
|
||||
|
@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
}
|
||||
|
||||
// for non_provider defs, collect the nodes that expect it is provider tensor as input/output.
|
||||
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) {
|
||||
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
|
||||
const KernelRegistryManager& kernel_registries,
|
||||
const logging::Logger& logger) {
|
||||
for (auto& it : graph_.Nodes()) {
|
||||
if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
|
||||
auto input_it =
|
||||
|
@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
|
|||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
|
||||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
|
||||
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
|
||||
if (arg_input_index != -1) {
|
||||
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
|
||||
}
|
||||
|
@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co
|
|||
// We duplicate any initializer that is used by both provider nodes and non-provider nodes
|
||||
// to ensure that provider nodes and non-provider nodes don't share initializers, as they
|
||||
// need to stay in different memory locations.
|
||||
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) {
|
||||
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries,
|
||||
const InitializedTensorSet& initializers_consumed,
|
||||
const logging::Logger& logger) {
|
||||
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacements;
|
||||
for (const auto& pair : initializers_consumed) {
|
||||
const auto& name = pair.first;
|
||||
|
@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
|
|||
auto dup_replacements = replacements;
|
||||
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci);
|
||||
auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
if (kci == nullptr) continue;
|
||||
if (kci->kernel_def == nullptr) continue;
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "hardware_core_enumerator.h"
|
||||
#include "core/platform/windows/env.h"
|
||||
#include <memory>
|
||||
#include <Windows.h>
|
||||
#include <assert.h>
|
||||
|
@ -83,6 +84,38 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
|
|||
// # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
|
||||
// # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
|
||||
auto cores = GetCoreInfo();
|
||||
#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__)
|
||||
const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI"
|
||||
bool isIntelSpecifiedPlatform = false;
|
||||
const int kVendorID_IntelSpecifiedPlatformIDs[3] = {
|
||||
// ExtendedModel, ExtendedFamily, Family Code, and Model Number
|
||||
0xa06a, // MTL
|
||||
0xc065, // ARL-H
|
||||
0xb065 // ARL-U
|
||||
};
|
||||
|
||||
int regs_leaf0[4];
|
||||
int regs_leaf1[4];
|
||||
__cpuid(regs_leaf0, 0);
|
||||
__cpuid(regs_leaf1, 0x1);
|
||||
|
||||
auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]);
|
||||
|
||||
for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) {
|
||||
if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) {
|
||||
isIntelSpecifiedPlatform = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isIntel) {
|
||||
if (isIntelSpecifiedPlatform) {
|
||||
// We want to exclude cores without an LLC
|
||||
return cores.LLCCores;
|
||||
} else {
|
||||
return cores.PhysicalCores;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return cores.LLCCores;
|
||||
}
|
||||
|
|
|
@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
|
|||
|
||||
const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node);
|
||||
if (cann_kernel_def == nullptr) {
|
||||
LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
|
||||
<< " node name: " << node.Name();
|
||||
LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
|
||||
<< " node name: " << node.Name();
|
||||
continue;
|
||||
}
|
||||
|
||||
candidates.push_back(node.Index());
|
||||
}
|
||||
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger());
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
continue;
|
||||
|
|
|
@ -151,7 +151,7 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
|
||||
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
|
||||
// To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
|
||||
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {
|
||||
|
|
|
@ -133,9 +133,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
|
||||
// to pass https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1563483&view=logs&j=f7cc61a9-cc70-56e7-b06c-4668ca17e426
|
||||
// ReductionOpTest.ReduceSum_half_bert
|
||||
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
|
||||
// skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros
|
||||
int32_t input_type;
|
||||
GetType(*input_defs[0], input_type, logger);
|
||||
if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
|
|
|
@ -13,6 +13,10 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/cpu/tensor/unsqueeze.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <TargetConditionals.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
|
@ -54,32 +58,50 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
|
|||
}
|
||||
}
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
|
||||
const Node& node, const logging::Logger& logger) {
|
||||
const auto& input_defs(node.InputDefs());
|
||||
TensorShapeVector axes;
|
||||
GetAxes(model_builder, node, axes);
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
GetShape(*input_defs[0], input_shape, logger);
|
||||
auto op = model_builder.CreateOperation(node, "reshape");
|
||||
AddOperationInput(*op, "x", input_defs[0]->Name());
|
||||
TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
|
||||
AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0]);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
}
|
||||
#endif
|
||||
|
||||
Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
[[maybe_unused]] const logging::Logger& logger) const {
|
||||
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
|
||||
const auto& input_defs(node.InputDefs());
|
||||
auto* coreml_squeeze = layer->mutable_squeeze();
|
||||
TensorShapeVector axes;
|
||||
GetAxes(model_builder, node, axes);
|
||||
std::vector<int64_t> input_shape;
|
||||
GetShape(*input_defs[0], input_shape, logger);
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
const auto& input_defs(node.InputDefs());
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
|
||||
std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "reshape";
|
||||
#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64
|
||||
// expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input
|
||||
if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) {
|
||||
HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims";
|
||||
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
|
||||
AddOperationInput(*op, "x", input_defs[0]->Name());
|
||||
|
||||
if (coreml_op_type == "squeeze") {
|
||||
if (!axes.empty()) {
|
||||
// coreml squeeze op does support negative axes
|
||||
AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
|
||||
}
|
||||
} else {
|
||||
TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
|
||||
AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
|
||||
if (!axes.empty()) {
|
||||
// coreml supports negative axes
|
||||
AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
|
||||
}
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0]);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
|
|
|
@ -408,7 +408,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
|
|||
: graph_viewer_(graph_viewer),
|
||||
logger_(logger),
|
||||
coreml_version_(coreml_version),
|
||||
coreml_compute_unit_(coreml_options.ComputeUnits()),
|
||||
coreml_options_(coreml_options),
|
||||
create_ml_program_(coreml_options.CreateMLProgram()),
|
||||
model_output_path_(GetModelOutputPath(create_ml_program_)),
|
||||
onnx_input_names_(std::move(onnx_input_names)),
|
||||
|
@ -989,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
|
|||
get_sanitized_io_info(std::move(input_output_info_)),
|
||||
std::move(scalar_outputs_),
|
||||
std::move(int64_outputs_),
|
||||
logger_, coreml_compute_unit_);
|
||||
logger_, coreml_options_);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
|
@ -999,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
|
|||
std::move(input_output_info_),
|
||||
std::move(scalar_outputs_),
|
||||
std::move(int64_outputs_),
|
||||
logger_, coreml_compute_unit_);
|
||||
logger_, coreml_options_);
|
||||
}
|
||||
|
||||
return model->LoadModel(); // load using CoreML API, including compilation
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/coreml/builders/coreml_spec.h"
|
||||
#include "core/providers/coreml/model/model.h"
|
||||
#include "core/providers/coreml/coreml_options.h"
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
// coremltools classes
|
||||
|
@ -22,8 +23,6 @@ class StorageWriter;
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
class CoreMLOptions;
|
||||
|
||||
namespace coreml {
|
||||
|
||||
class IOpBuilder;
|
||||
|
@ -218,7 +217,7 @@ class ModelBuilder {
|
|||
const GraphViewer& graph_viewer_;
|
||||
const logging::Logger& logger_;
|
||||
const int32_t coreml_version_;
|
||||
const uint32_t coreml_compute_unit_;
|
||||
CoreMLOptions coreml_options_;
|
||||
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
|
||||
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
|
||||
|
||||
|
|
|
@ -63,11 +63,14 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
|
|||
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
|
||||
{"NeuralNetwork", COREML_FLAG_USE_NONE},
|
||||
};
|
||||
std::unordered_set<std::string> valid_options = {
|
||||
const std::unordered_set<std::string_view> valid_options = {
|
||||
kCoremlProviderOption_MLComputeUnits,
|
||||
kCoremlProviderOption_ModelFormat,
|
||||
kCoremlProviderOption_RequireStaticInputShapes,
|
||||
kCoremlProviderOption_EnableOnSubgraphs,
|
||||
kCoremlProviderOption_SpecializationStrategy,
|
||||
kCoremlProviderOption_ProfileComputePlan,
|
||||
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
|
||||
};
|
||||
// Validate the options
|
||||
for (const auto& option : options) {
|
||||
|
@ -90,6 +93,16 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
|
|||
require_static_shape_ = option.second == "1";
|
||||
} else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) {
|
||||
enable_on_subgraph_ = option.second == "1";
|
||||
} else if (kCoremlProviderOption_SpecializationStrategy == option.first) {
|
||||
if (option.second != "Default" && option.second != "FastPrediction") {
|
||||
ORT_THROW("Invalid value for option ", option.first, ": ", option.second,
|
||||
". Valid values are Default and FastPrediction.");
|
||||
}
|
||||
strategy_ = option.second;
|
||||
} else if (kCoremlProviderOption_ProfileComputePlan == option.first) {
|
||||
profile_compute_plan_ = option.second == "1";
|
||||
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
|
||||
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,9 @@ class CoreMLOptions {
|
|||
bool create_mlprogram_{false};
|
||||
bool enable_on_subgraph_{false};
|
||||
uint32_t compute_units_{0};
|
||||
std::string strategy_;
|
||||
bool profile_compute_plan_{false};
|
||||
bool allow_low_precision_accumulation_on_gpu_{false};
|
||||
|
||||
public:
|
||||
explicit CoreMLOptions(uint32_t coreml_flags);
|
||||
|
@ -25,6 +28,9 @@ class CoreMLOptions {
|
|||
bool CreateMLProgram() const { return create_mlprogram_; }
|
||||
bool EnableOnSubgraph() const { return enable_on_subgraph_; }
|
||||
uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; }
|
||||
bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; }
|
||||
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
|
||||
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }
|
||||
|
||||
private:
|
||||
void ValidateAndParseProviderOption(const ProviderOptions& options);
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
class CoreMLOptions;
|
||||
namespace coreml {
|
||||
|
||||
class Execution;
|
||||
|
@ -53,7 +54,7 @@ class Model {
|
|||
std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
|
||||
std::unordered_set<std::string>&& scalar_outputs,
|
||||
std::unordered_set<std::string>&& int64_outputs,
|
||||
const logging::Logger& logger, uint32_t coreml_compute_unit);
|
||||
const logging::Logger& logger, const CoreMLOptions& coreml_options);
|
||||
|
||||
~Model();
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "core/providers/coreml/model/host_utils.h"
|
||||
#include "core/providers/coreml/model/objc_str_utils.h"
|
||||
#include "core/providers/coreml/shape_utils.h"
|
||||
#include "core/providers/coreml/coreml_options.h"
|
||||
|
||||
// force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need
|
||||
// to manually do this
|
||||
|
@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// since __clang_major__ >= 15, MLComputePlan is introduced in <CoreML/CoreML.h>
|
||||
// We are actually ensure the MacOS/IOS version and Xcode version is greater than `macOS 14.4, iOS 17.4`.
|
||||
// The macro API_AVAILABLE should also be fine.
|
||||
// Otherwise, the compiler will complain `MLComputePlan` is not defined.
|
||||
// we define __clang_analyzer__ here is for bypass static analysis
|
||||
void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) {
|
||||
#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
|
||||
if (@available(macOS 14.4, iOS 17.4, *)) {
|
||||
[MLComputePlan loadContentsOfURL:compileUrl
|
||||
configuration:config
|
||||
completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) {
|
||||
if (!computePlan) {
|
||||
NSLog(@"Error loading compute plan: %@", error);
|
||||
// Handle error.
|
||||
return;
|
||||
}
|
||||
MLModelStructureProgram* program = computePlan.modelStructure.program;
|
||||
if (!program) {
|
||||
NSLog(@"Error loading program from compute plan., this is not a mlprogram model");
|
||||
return;
|
||||
}
|
||||
|
||||
MLModelStructureProgramFunction* mainFunction = program.functions[@"main"];
|
||||
if (!mainFunction) {
|
||||
NSLog(@"Error loading main function from program");
|
||||
return;
|
||||
}
|
||||
|
||||
NSArray<MLModelStructureProgramOperation*>* operations = mainFunction.block.operations;
|
||||
NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count);
|
||||
for (MLModelStructureProgramOperation* operation in operations) {
|
||||
// Get the compute device usage for the operation.
|
||||
MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation];
|
||||
id<MLComputeDeviceProtocol> preferredDevice = computeDeviceUsage.preferredComputeDevice;
|
||||
// Get the estimated cost of executing the operation.
|
||||
MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation];
|
||||
if (![operation.operatorName isEqualToString:@"const"]) {
|
||||
NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight);
|
||||
}
|
||||
}
|
||||
}];
|
||||
} else {
|
||||
NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Internal Execution class
|
||||
// This class is part of the model class and handles the calls into CoreML. Specifically, it performs
|
||||
// 1. Compile the model by given path for execution
|
||||
|
@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
|
|||
// 3. The compiled model will be removed in dealloc or removed using cleanup function
|
||||
class Execution {
|
||||
public:
|
||||
Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags);
|
||||
Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options);
|
||||
~Execution();
|
||||
|
||||
Status LoadModel();
|
||||
|
@ -320,13 +368,13 @@ class Execution {
|
|||
NSString* coreml_model_path_{nil};
|
||||
NSString* compiled_model_path_{nil};
|
||||
const logging::Logger& logger_;
|
||||
uint32_t coreml_compute_unit_{0};
|
||||
CoreMLOptions coreml_options_;
|
||||
MLModel* model_{nil};
|
||||
};
|
||||
|
||||
Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit)
|
||||
Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options)
|
||||
: logger_(logger),
|
||||
coreml_compute_unit_(coreml_compute_unit) {
|
||||
coreml_options_(coreml_options) {
|
||||
@autoreleasepool {
|
||||
coreml_model_path_ = util::Utf8StringToNSString(path.c_str());
|
||||
}
|
||||
|
@ -395,17 +443,41 @@ Status Execution::LoadModel() {
|
|||
compiled_model_path_ = [compileUrl path];
|
||||
|
||||
MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
|
||||
|
||||
if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) {
|
||||
uint32_t coreml_compute_unit = coreml_options_.ComputeUnits();
|
||||
if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) {
|
||||
config.computeUnits = MLComputeUnitsCPUOnly;
|
||||
} else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) {
|
||||
} else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) {
|
||||
config.computeUnits = MLComputeUnitsCPUAndGPU;
|
||||
} else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
|
||||
} else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
|
||||
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
|
||||
} else {
|
||||
config.computeUnits = MLComputeUnitsAll;
|
||||
}
|
||||
|
||||
if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) {
|
||||
config.allowLowPrecisionAccumulationOnGPU = YES;
|
||||
}
|
||||
|
||||
// Set the specialization strategy to FastPrediction for macOS 10.15+
|
||||
// since __clang_major__ >= 15, optimizationHints is introduced in <CoreML/CoreML.h>
|
||||
// Same as above comments for why we are checking __clang_major__.
|
||||
// we define __clang_analyzer__ here is for bypass static analysis
|
||||
#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
|
||||
if (HAS_COREML8_OR_LATER) {
|
||||
MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init];
|
||||
if (coreml_options_.UseStrategy("FastPrediction")) {
|
||||
optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
|
||||
config.optimizationHints = optimizationHints;
|
||||
} else if (coreml_options_.UseStrategy("Default")) {
|
||||
optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
|
||||
config.optimizationHints = optimizationHints;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (coreml_options_.ProfileComputePlan()) {
|
||||
ProfileComputePlan(compileUrl, config);
|
||||
}
|
||||
|
||||
model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error];
|
||||
|
||||
if (error != nil || model_ == nil) {
|
||||
|
@ -524,8 +596,8 @@ Model::Model(const std::string& path,
|
|||
std::unordered_set<std::string>&& scalar_outputs,
|
||||
std::unordered_set<std::string>&& int64_outputs,
|
||||
const logging::Logger& logger,
|
||||
uint32_t coreml_flags)
|
||||
: execution_(std::make_unique<Execution>(path, logger, coreml_flags)),
|
||||
const CoreMLOptions& coreml_options)
|
||||
: execution_(std::make_unique<Execution>(path, logger, coreml_options)),
|
||||
model_input_names_(std::move(model_input_names)),
|
||||
model_output_names_(std::move(model_output_names)),
|
||||
input_output_info_(std::move(input_output_info)),
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/providers/coreml/model/model.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class CoreMLOptions;
|
||||
namespace coreml {
|
||||
|
||||
class Execution {};
|
||||
|
@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/,
|
|||
std::unordered_set<std::string>&& scalar_outputs,
|
||||
std::unordered_set<std::string>&& int64_outputs,
|
||||
const logging::Logger& /*logger*/,
|
||||
uint32_t /*coreml_flags*/)
|
||||
const CoreMLOptions& /*coreml_flags*/)
|
||||
: execution_(std::make_unique<Execution>()),
|
||||
model_input_names_(std::move(model_input_names)),
|
||||
model_output_names_(std::move(model_output_names)),
|
||||
|
|
|
@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
// For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
|
||||
// These are usually shape related computation subgraphs
|
||||
// Following logic can be extended for other EPs
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
|
|
|
@ -62,7 +62,8 @@ namespace Dml
|
|||
const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
|
||||
const auto kernel_lookup = onnxruntime::KernelLookup{provider_type,
|
||||
gsl::make_span(®istry, 1),
|
||||
kernel_type_str_resolver};
|
||||
kernel_type_str_resolver,
|
||||
logger};
|
||||
|
||||
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
|
||||
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;
|
||||
|
|
|
@ -54,7 +54,8 @@ namespace Dml
|
|||
const auto kernelLookup = onnxruntime::KernelLookup(
|
||||
providerType,
|
||||
gsl::make_span(®istry, 1),
|
||||
kernelTypeStrResolver);
|
||||
kernelTypeStrResolver,
|
||||
logger);
|
||||
|
||||
onnxruntime::GraphViewer graphViewer(graph);
|
||||
const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -95,7 +95,7 @@ namespace Dml
|
|||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
|
||||
{
|
||||
#ifdef ENABLE_GRAPH_COMPILATION
|
||||
return m_impl->GetCapability(graph, kernel_lookup);
|
||||
return m_impl->GetCapability(graph, kernel_lookup, *GetLogger());
|
||||
#else
|
||||
return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup);
|
||||
#endif
|
||||
|
@ -876,7 +876,8 @@ namespace Dml
|
|||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
ExecutionProviderImpl::GetCapability(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
const onnxruntime::logging::Logger& logger) const
|
||||
{
|
||||
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
|
||||
|
@ -900,7 +901,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
// Get the list of nodes that should stay on the CPU
|
||||
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes);
|
||||
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger);
|
||||
|
||||
for (size_t nodeIndex : toplogicalOrder)
|
||||
{
|
||||
|
|
|
@ -88,7 +88,8 @@ namespace Dml
|
|||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
GetCapability(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup
|
||||
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
const onnxruntime::logging::Logger& logger
|
||||
) const;
|
||||
|
||||
uint32_t GetSupportedDeviceDataTypeMask() const;
|
||||
|
|
|
@ -818,7 +818,7 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
|
|||
candidates.push_back(node.Index());
|
||||
tenative_candidates.push_back(node.Index());
|
||||
}
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0) {
|
||||
|
|
|
@ -32,8 +32,16 @@ namespace nnapi {
|
|||
|
||||
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices,
|
||||
TargetDeviceOption target_device_option)
|
||||
: nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique<Model>(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) {
|
||||
TargetDeviceOption target_device_option,
|
||||
const logging::Logger& logger)
|
||||
: nnapi_(nnapi_handle),
|
||||
graph_viewer_(graph_viewer),
|
||||
nnapi_model_{std::make_unique<Model>(nnapi_handle)},
|
||||
shaper_{graph_viewer},
|
||||
nnapi_target_devices_(nnapi_target_devices),
|
||||
target_device_option_(target_device_option),
|
||||
nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)),
|
||||
logger_(logger) {
|
||||
nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_;
|
||||
}
|
||||
|
||||
|
@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
|
|||
}
|
||||
|
||||
void ModelBuilder::PreprocessNodeUnits() {
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
|
||||
}
|
||||
|
||||
// Help to get all quantized operators' input and the NodeUnit(s) using the input
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
|
||||
struct NnApi;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
class GraphViewer;
|
||||
enum class DataLayout;
|
||||
class NodeUnit;
|
||||
|
@ -31,7 +33,8 @@ class ModelBuilder {
|
|||
using Shape = Shaper::Shape;
|
||||
|
||||
ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option);
|
||||
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option,
|
||||
const logging::Logger& logger);
|
||||
|
||||
common::Status Compile(std::unique_ptr<Model>& model);
|
||||
|
||||
|
@ -173,6 +176,9 @@ class ModelBuilder {
|
|||
// <1,1> <1,2> <1,3>
|
||||
InlinedVector<std::pair<size_t, int32_t>> operations_recorder_;
|
||||
#endif
|
||||
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// Convert the ONNX model to ANeuralNetworksModel
|
||||
common::Status Prepare();
|
||||
|
||||
|
|
|
@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {}
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const {
|
||||
const auto& logger = *GetLogger();
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
// TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators)
|
||||
|
@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL;
|
||||
#endif
|
||||
}();
|
||||
LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
|
||||
LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
|
||||
|
||||
const nnapi::OpSupportCheckParams params{
|
||||
android_feature_level,
|
||||
|
@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
};
|
||||
|
||||
if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) {
|
||||
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
|
||||
LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
|
||||
<< params.android_feature_level
|
||||
<< "] is lower than minimal supported NNAPI API feature level ["
|
||||
<< ORT_NNAPI_MIN_API_LEVEL
|
||||
|
@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit to be checked for multiple times
|
||||
|
@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
node_unit_supported_result[node_unit] = supported;
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported
|
||||
LOGS(logger, VERBOSE) << "Node supported: [" << supported
|
||||
<< "] Operator type: [" << node.OpType()
|
||||
<< "] index: [" << node.Index()
|
||||
<< "] name: [" << node.Name()
|
||||
|
@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
// If the graph is partitioned in multiple subgraphs, and this may impact performance,
|
||||
// we want to give users a summary message at warning level.
|
||||
if (num_of_partitions > 1) {
|
||||
LOGS_DEFAULT(WARNING) << summary_msg;
|
||||
LOGS(logger, WARNING) << summary_msg;
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << summary_msg;
|
||||
LOGS(logger, INFO) << summary_msg;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context,
|
|||
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
using namespace android::nn::wrapper;
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
|
||||
Node& fused_node = fused_node_and_graph.fused_node;
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
|
||||
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_);
|
||||
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger);
|
||||
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
|
||||
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
|
||||
|
||||
|
|
|
@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
|
|||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger);
|
||||
|
||||
std::unordered_set<const NodeUnit*> seen_node_units;
|
||||
const auto& node_indices = src_graph.GetNodesInTopologicalOrder();
|
||||
|
|
|
@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
|
|||
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
|
||||
const onnxruntime::PathString& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModelLookupTable& qnn_models) {
|
||||
QnnModelLookupTable& qnn_models,
|
||||
int64_t max_spill_fill_size) {
|
||||
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
|
||||
NodeAttrHelper node_helper(main_context_node);
|
||||
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
|
||||
|
@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
|
|||
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
|
||||
static_cast<uint64_t>(context_binary.length()),
|
||||
main_context_node.Name(),
|
||||
qnn_models);
|
||||
qnn_models,
|
||||
max_spill_fill_size);
|
||||
}
|
||||
|
||||
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
|
||||
|
@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
|
|||
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
|
||||
static_cast<uint64_t>(buffer_size),
|
||||
main_context_node.Name(),
|
||||
qnn_models);
|
||||
qnn_models,
|
||||
max_spill_fill_size);
|
||||
}
|
||||
|
||||
Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
uint32_t total_context_size,
|
||||
int64_t& max_spill_fill_size,
|
||||
std::vector<int>& main_context_pos_list) {
|
||||
max_spill_fill_size = 0;
|
||||
int max_size_index = 0;
|
||||
for (uint32_t i = 0; i < total_context_size; ++i) {
|
||||
auto index = main_context_pos_list[i];
|
||||
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
|
||||
ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
|
||||
const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
|
||||
NodeAttrHelper node_helper(*ep_context_node);
|
||||
int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
|
||||
if (max_size > max_spill_fill_size) {
|
||||
max_spill_fill_size = max_size;
|
||||
max_size_index = i;
|
||||
}
|
||||
}
|
||||
if (0 != max_size_index) {
|
||||
int tmp_index = main_context_pos_list[0];
|
||||
main_context_pos_list[0] = main_context_pos_list[max_size_index];
|
||||
main_context_pos_list[max_size_index] = tmp_index;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const onnxruntime::PathString& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModelLookupTable& qnn_models,
|
||||
const logging::Logger& logger) {
|
||||
const logging::Logger& logger,
|
||||
int64_t max_spill_fill_size) {
|
||||
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
|
||||
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
|
||||
qnn_models);
|
||||
qnn_models, max_spill_fill_size);
|
||||
|
||||
// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
|
||||
if (!status.IsOK()) {
|
||||
|
@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
|
|||
const QnnModelLookupTable& qnn_models,
|
||||
const onnxruntime::PathString& context_cache_path,
|
||||
bool qnn_context_embed_mode,
|
||||
uint64_t max_spill_fill_buffer_size,
|
||||
const logging::Logger& logger) {
|
||||
auto& graph = model->MainGraph();
|
||||
|
||||
|
@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model,
|
|||
}
|
||||
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
|
||||
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
|
||||
ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
|
||||
}
|
||||
} else {
|
||||
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
|
||||
|
|
|
@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
|
|||
static const std::string EP_SDK_VER = "ep_sdk_version";
|
||||
static const std::string PARTITION_NAME = "partition_name";
|
||||
static const std::string SOURCE = "source";
|
||||
static const std::string MAX_SIZE = "max_size";
|
||||
|
||||
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
|
||||
|
||||
|
@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
|
|||
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
|
||||
const onnxruntime::PathString& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModelLookupTable& qnn_models);
|
||||
QnnModelLookupTable& qnn_models,
|
||||
int64_t max_spill_fill_size);
|
||||
|
||||
Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
uint32_t total_context_size,
|
||||
int64_t& max_spill_fill_size,
|
||||
std::vector<int>& main_context_pos_list);
|
||||
|
||||
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const onnxruntime::PathString& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModelLookupTable& qnn_models,
|
||||
const logging::Logger& logger);
|
||||
const logging::Logger& logger,
|
||||
int64_t max_spill_fill_size);
|
||||
|
||||
Status CreateEPContextNodes(Model* model,
|
||||
unsigned char* buffer,
|
||||
|
@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
|
|||
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
|
||||
const onnxruntime::PathString& context_cache_path,
|
||||
bool qnn_context_embed_mode,
|
||||
uint64_t max_spill_fill_buffer_size,
|
||||
const logging::Logger& logger);
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include <string>
|
||||
#include "QnnOpDef.h"
|
||||
#include "HTP/QnnHtpPerfInfrastructure.h"
|
||||
#include "HTP/QnnHtpSystemContext.h"
|
||||
#include "CPU/QnnCpuCommon.h"
|
||||
// TODO: not exist for Windows yet
|
||||
// #include "GPU/QnnGpuCommon.h"
|
||||
|
@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
|
|||
}
|
||||
|
||||
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
|
||||
QnnHtpContext_CustomConfig_t customConfig;
|
||||
customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
|
||||
customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
|
||||
QnnHtpContext_CustomConfig_t custom_config;
|
||||
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
|
||||
custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
|
||||
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
|
||||
context_config_weight_sharing.customConfig = &customConfig;
|
||||
context_config_weight_sharing.customConfig = &custom_config;
|
||||
|
||||
QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
|
||||
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
|
||||
|
@ -615,9 +616,9 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
|
|||
return context_buffer;
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
|
||||
std::string node_name,
|
||||
QnnModelLookupTable& qnn_models) {
|
||||
Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
|
||||
uint64_t buffer_length,
|
||||
uint64_t& max_spill_fill_buffer_size) {
|
||||
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
|
||||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
|
||||
nullptr == qnn_sys_interface_.systemContextFree;
|
||||
|
@ -638,7 +639,69 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
|
|||
|
||||
// binary_info life cycle is here
|
||||
// Binary info to graph info
|
||||
// retrieve Qnn graph infor from binary info
|
||||
// retrieve Qnn graph info from binary info
|
||||
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
|
||||
uint32_t graph_count = 0;
|
||||
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
|
||||
if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
|
||||
graph_count = binary_info->contextBinaryInfoV3.numGraphs;
|
||||
graphs_info = binary_info->contextBinaryInfoV3.graphs;
|
||||
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
|
||||
graph_count = binary_info->contextBinaryInfoV2.numGraphs;
|
||||
graphs_info = binary_info->contextBinaryInfoV2.graphs;
|
||||
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
|
||||
graph_count = binary_info->contextBinaryInfoV1.numGraphs;
|
||||
graphs_info = binary_info->contextBinaryInfoV1.graphs;
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < graph_count; ++i) {
|
||||
if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
|
||||
auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
|
||||
if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
|
||||
auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
|
||||
max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
|
||||
} else {
|
||||
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
|
||||
}
|
||||
} else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
|
||||
graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
|
||||
LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
|
||||
} else {
|
||||
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
|
||||
}
|
||||
}
|
||||
|
||||
LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
|
||||
std::string node_name,
|
||||
QnnModelLookupTable& qnn_models,
|
||||
int64_t max_spill_fill_size) {
|
||||
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
|
||||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
|
||||
nullptr == qnn_sys_interface_.systemContextFree;
|
||||
ORT_RETURN_IF(result, "Failed to get valid function pointer.");
|
||||
|
||||
QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
|
||||
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
|
||||
|
||||
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
|
||||
Qnn_ContextBinarySize_t binary_info_size{0};
|
||||
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
|
||||
static_cast<void*>(buffer),
|
||||
buffer_length,
|
||||
&binary_info,
|
||||
&binary_info_size);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
|
||||
|
||||
// binary_info life cycle is here
|
||||
// Binary info to graph info
|
||||
// retrieve Qnn graph info from binary info
|
||||
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
|
||||
uint32_t graph_count = 0;
|
||||
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
|
||||
|
@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
|
|||
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
|
||||
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;
|
||||
|
||||
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
|
||||
"Invalid function pointer for contextCreateFromBinary.");
|
||||
|
||||
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
|
||||
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
|
||||
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
|
||||
|
||||
// Register spill fill buffer for multi context
|
||||
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
|
||||
|
||||
// The spill fill buffer is available since 2.28, API version starts from 2.21
|
||||
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
|
||||
QnnHtpContext_CustomConfig_t custom_config;
|
||||
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
|
||||
QnnHtpContext_GroupRegistration_t group_info;
|
||||
size_t current_contexts_size = GetQnnContextSize();
|
||||
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
|
||||
// note that we already move the context with max spill fill size to the beginning of the list
|
||||
group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
|
||||
group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
|
||||
custom_config.groupRegistration = group_info;
|
||||
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
|
||||
spill_fill_config.customConfig = &custom_config;
|
||||
#endif
|
||||
QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
|
||||
LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
|
||||
|
||||
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
|
||||
|
||||
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
|
||||
"Invalid function pointer for contextCreateFromBinary.");
|
||||
Qnn_ContextHandle_t context = nullptr;
|
||||
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
|
||||
device_handle_,
|
||||
|
@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
|
|||
buffer_length,
|
||||
&context,
|
||||
profile_backend_handle_);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
|
||||
contexts_.push_back(context);
|
||||
if (1 == graph_count) {
|
||||
// in case the EPContext node is generated from script
|
||||
|
@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
|
||||
// need to load system lib if load from Qnn context binary
|
||||
// or generate Qnn context binary is enabled -- to get the max spill fill buffer size
|
||||
Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
|
||||
bool load_from_cached_context,
|
||||
bool need_load_system_lib) {
|
||||
std::lock_guard<std::mutex> lock(logger_mutex_);
|
||||
if (backend_setup_completed_) {
|
||||
LOGS(logger, VERBOSE) << "Backend setup already!";
|
||||
|
@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
|
|||
|
||||
LOGS(logger, VERBOSE) << "LoadBackend succeed.";
|
||||
|
||||
if (load_from_cached_context) {
|
||||
if (load_from_cached_context || need_load_system_lib) {
|
||||
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
|
||||
}
|
||||
|
||||
|
@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void QnnBackendManager::Split(std::vector<std::string>& split_string,
|
||||
const std::string& tokenized_string,
|
||||
const char separator) {
|
||||
split_string.clear();
|
||||
std::istringstream tokenized_string_stream(tokenized_string);
|
||||
while (!tokenized_string_stream.eof()) {
|
||||
std::string value;
|
||||
getline(tokenized_string_stream, value, separator);
|
||||
if (!value.empty()) {
|
||||
split_string.push_back(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
|
||||
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
|
||||
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
|
||||
|
|
|
@ -93,9 +93,10 @@ class QnnBackendManager {
|
|||
|
||||
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
|
||||
std::string node_name,
|
||||
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
|
||||
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
|
||||
int64_t max_spill_fill_size);
|
||||
|
||||
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
|
||||
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);
|
||||
|
||||
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
|
||||
|
||||
|
@ -112,6 +113,10 @@ class QnnBackendManager {
|
|||
return contexts_[index];
|
||||
}
|
||||
|
||||
size_t GetQnnContextSize() {
|
||||
return contexts_.size();
|
||||
}
|
||||
|
||||
const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }
|
||||
|
||||
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
|
||||
|
@ -145,8 +150,6 @@ class QnnBackendManager {
|
|||
|
||||
void ReleaseResources();
|
||||
|
||||
void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);
|
||||
|
||||
Status ExtractBackendProfilingInfo();
|
||||
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
|
||||
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
|
||||
|
@ -163,6 +166,10 @@ class QnnBackendManager {
|
|||
|
||||
Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
|
||||
|
||||
Status GetMaxSpillFillBufferSize(unsigned char* buffer,
|
||||
uint64_t buffer_length,
|
||||
uint64_t& max_spill_fill_buffer_size);
|
||||
|
||||
private:
|
||||
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
|
|||
// valid throughout the lifetime of the ModelBuilder
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This name must be same with the EPContext node name
|
||||
const auto& graph_name = fused_node.Name();
|
||||
|
|
|
@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
|
|||
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
|
||||
}
|
||||
|
||||
bool enable_htp_weight_sharing = false;
|
||||
static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
|
||||
auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
|
||||
if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
|
||||
if ("1" == htp_weight_sharing_enabled_pos->second) {
|
||||
enable_htp_weight_sharing_ = true;
|
||||
enable_htp_weight_sharing = true;
|
||||
} else if ("0" == htp_weight_sharing_enabled_pos->second) {
|
||||
enable_htp_weight_sharing_ = false;
|
||||
enable_htp_weight_sharing = false;
|
||||
} else {
|
||||
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
|
||||
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
|
||||
<< " only 0 or 1 allowed. Set to 0.";
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
|
||||
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
|
||||
}
|
||||
|
||||
// Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform
|
||||
enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map);
|
||||
|
||||
model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
|
||||
provider_options_map);
|
||||
|
||||
|
@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
|
|||
device_id_,
|
||||
htp_arch,
|
||||
soc_model,
|
||||
enable_htp_weight_sharing_);
|
||||
enable_htp_weight_sharing);
|
||||
|
||||
#ifdef _WIN32
|
||||
auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
|
||||
|
@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
|
||||
// It will load the QnnSystem lib if is_qnn_ctx_model=true, and
|
||||
// delay the Qnn context creation to Compile() using the cached context binary
|
||||
auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
|
||||
// or generate context cache enable, need to use use QnnSystem lib to parse the binary to get the max spill fill buffer size
|
||||
auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
|
||||
if (Status::OK() != rt) {
|
||||
LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
|
||||
return result;
|
||||
|
@ -713,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// remove is_qnn_ctx_model related code
|
||||
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map,
|
||||
|
@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
|
||||
std::vector<int> main_context_pos_list;
|
||||
ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
|
||||
uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());
|
||||
|
||||
int64_t max_spill_fill_size = 0;
|
||||
|
||||
// Adjust the main_context_pos_list, move the one with max spill fill buffer to the beginning
|
||||
// HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28
|
||||
if (total_context_size > 1) {
|
||||
ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size,
|
||||
max_spill_fill_size, main_context_pos_list));
|
||||
}
|
||||
|
||||
for (auto main_context_pos : main_context_pos_list) {
|
||||
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
|
||||
|
@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
context_cache_path,
|
||||
qnn_backend_manager_.get(),
|
||||
qnn_models,
|
||||
logger));
|
||||
logger,
|
||||
max_spill_fill_size));
|
||||
}
|
||||
|
||||
for (auto fused_node_and_graph : fused_nodes_and_graphs) {
|
||||
|
@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
// All partitioned graph share single QNN context, included in the same context binary
|
||||
uint64_t buffer_size(0);
|
||||
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
|
||||
// Get max spill fill buffer size
|
||||
uint64_t max_spill_fill_buffer_size = 0;
|
||||
if (enable_spill_fill_buffer_) {
|
||||
ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
|
||||
buffer_size,
|
||||
max_spill_fill_buffer_size));
|
||||
}
|
||||
qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
|
||||
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
|
||||
context_buffer.get(),
|
||||
|
@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
qnn_models_,
|
||||
context_cache_path,
|
||||
qnn_context_embed_mode_,
|
||||
max_spill_fill_buffer_size,
|
||||
logger));
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
|
|||
std::string context_node_name_prefix_ = "";
|
||||
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
|
||||
bool qnn_context_embed_mode_ = true;
|
||||
bool enable_htp_weight_sharing_ = false;
|
||||
int32_t vtcm_size_in_mb_ = 0;
|
||||
std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
|
||||
ModelMetadefIdGenerator metadef_id_generator_;
|
||||
|
@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider {
|
|||
uint32_t default_rpc_control_latency_ = 0;
|
||||
bool enable_HTP_FP16_precision_ = true;
|
||||
bool share_ep_contexts_ = false;
|
||||
bool enable_spill_fill_buffer_ = false;
|
||||
#ifdef _WIN32
|
||||
onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr;
|
||||
#endif
|
||||
|
|
|
@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
// For ROCM EP, exclude the subgraph that is preferred to be placed in CPU
|
||||
// These are usually shape related computation subgraphs
|
||||
// Following logic can be extended for other EPs
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0)
|
||||
|
|
|
@ -294,7 +294,8 @@ std::unique_ptr<IDataTransfer> CreateGPUDataTransfer();
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger);
|
||||
|
||||
std::string GetEnvironmentVar(const std::string& var_name);
|
||||
|
||||
|
@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
|
|||
|
||||
namespace QDQ {
|
||||
inline std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer* graph_viewer) {
|
||||
return g_host->QDQ__GetAllNodeUnits(graph_viewer);
|
||||
GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) {
|
||||
return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger);
|
||||
}
|
||||
} // namespace QDQ
|
||||
|
||||
|
|
|
@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) {
|
|||
|
||||
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) {
|
||||
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) {
|
||||
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
|
||||
}
|
||||
|
||||
namespace profiling {
|
||||
|
|
|
@ -202,7 +202,8 @@ struct ProviderHost {
|
|||
|
||||
virtual std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
|
||||
const IExecutionProvider::IKernelLookup& kernel_lookup,
|
||||
gsl::span<const NodeIndex> tentative_nodes) = 0;
|
||||
gsl::span<const NodeIndex> tentative_nodes,
|
||||
const logging::Logger& logger) = 0;
|
||||
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
|
||||
|
@ -890,7 +891,7 @@ struct ProviderHost {
|
|||
virtual std::unique_ptr<Node__EdgeIterator> NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0;
|
||||
|
||||
virtual std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0;
|
||||
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0;
|
||||
|
||||
// Model
|
||||
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
|
||||
|
|
|
@ -34,7 +34,8 @@ namespace onnxruntime {
|
|||
|
||||
namespace vsi {
|
||||
namespace npu {
|
||||
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) {
|
||||
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger)
|
||||
: graph_viewer_(graph_viewer), logger_(logger) {
|
||||
Prepare();
|
||||
context_ = tim::vx::Context::Create();
|
||||
graph_ = context_->CreateGraph();
|
||||
|
@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g
|
|||
}
|
||||
|
||||
bool GraphEP::Prepare() {
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
|
||||
for (const auto& node_unit : node_unit_holder_) {
|
||||
auto quant_op_type = util::GetQuantizedOpType(*node_unit);
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ struct NodeIOInfo {
|
|||
|
||||
class GraphEP {
|
||||
public:
|
||||
explicit GraphEP(const GraphViewer& graph_viewer);
|
||||
explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
~GraphEP() {}
|
||||
|
||||
bool Prepare();
|
||||
|
@ -104,6 +104,7 @@ class GraphEP {
|
|||
// In the form of {input_name, [NodeUnit(s) using the input]}
|
||||
std::unordered_map<std::string, std::vector<const NodeUnit*>> all_quantized_op_inputs_;
|
||||
const GraphViewer& graph_viewer_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
// Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
|
||||
// valid throughout the lifetime of the ModelBuilder
|
||||
|
|
|
@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {}
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const {
|
||||
const auto& logger = *GetLogger();
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
if (graph_viewer.IsSubgraph()) {
|
||||
|
@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit to be checked for multiple times
|
||||
|
@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
}
|
||||
|
||||
Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
||||
OrtKernelContext* context) {
|
||||
OrtKernelContext* context,
|
||||
const logging::Logger& logger) {
|
||||
Ort::KernelContext ctx(context);
|
||||
size_t num_in = ctx.GetInputCount();
|
||||
const size_t num_inputs = graph_ep->GetGraphInputs().size();
|
||||
|
@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
|||
}
|
||||
|
||||
if (!graph_ep->GetGraph()->Run()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to run graph.";
|
||||
LOGS(logger, ERROR) << "Failed to run graph.";
|
||||
}
|
||||
for (size_t i = 0; i < ctx.GetOutputCount(); i++) {
|
||||
auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor;
|
||||
|
@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
|
|||
|
||||
Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_viewer = fused_node_graph.filtered_graph;
|
||||
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer);
|
||||
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer, logger);
|
||||
|
||||
for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
|
||||
LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
|
||||
<< graph_viewer.IsInitializedTensor(tensor->Name());
|
||||
auto input = std::make_shared<vsi::npu::GraphIOInfo>();
|
||||
input->name = tensor->Name();
|
||||
|
@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
graph_ep->GetGraphInputs().push_back(input);
|
||||
}
|
||||
for (auto tensor : graph_viewer.GetOutputs()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
|
||||
LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
|
||||
auto output = std::make_shared<vsi::npu::GraphIOInfo>();
|
||||
output->name = tensor->Name();
|
||||
output->is_initializer = false;
|
||||
|
@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
if (node != &node_unit.GetNode()) {
|
||||
continue;
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]";
|
||||
LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]";
|
||||
vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit);
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(INFO) << "Verifying graph";
|
||||
LOGS(logger, INFO) << "Verifying graph";
|
||||
graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile();
|
||||
if (!graph_ep->GetCompiled()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to verify graph.";
|
||||
LOGS(logger, ERROR) << "Failed to verify graph.";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "Graph has been verified successfully.";
|
||||
LOGS(logger, INFO) << "Graph has been verified successfully.";
|
||||
}
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
|
@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
|
|||
[graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */,
|
||||
OrtKernelContext* context) {
|
||||
std::lock_guard<std::mutex> lock(this->GetMutex());
|
||||
Status res = ComputeStateFunc(graph_ep.get(), context);
|
||||
Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger());
|
||||
return res;
|
||||
};
|
||||
|
||||
|
|
|
@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
1, 8,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
9, 10,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
11, 12,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
13, 20,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
21,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -798,7 +798,8 @@ std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapa
|
|||
candidates.push_back(node.Index());
|
||||
tenative_candidates.push_back(node.Index());
|
||||
}
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
|
||||
|
||||
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& node_index : candidates) {
|
||||
if (cpu_nodes.count(node_index) > 0) {
|
||||
|
|
|
@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
|
|||
}
|
||||
}
|
||||
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
|
||||
const logging::Logger& logger, bool allow_empty_input) {
|
||||
const auto& node_arg_name = node_arg.Name();
|
||||
const auto* shape_proto = node_arg.Shape();
|
||||
// Optional tensors can be indicated by an empty name, just ignore it.
|
||||
|
@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
|
|||
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
|
||||
return false;
|
||||
}
|
||||
if (dim.dim_value() == 0) {
|
||||
if (dim.dim_value() == 0 && !allow_empty_input) {
|
||||
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
|
|||
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
|
||||
}
|
||||
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
|
||||
const logging::Logger& logger, bool allow_empty_input = false);
|
||||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
|
|
|
@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
|
|||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
if (!IsTensorShapeSupported(*input, node_name, logger)) {
|
||||
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче