Merge branch 'main' into Cjian/jdk17-js

# Conflicts:
#	js/react_native/android/gradle/wrapper/gradle-wrapper.properties
This commit is contained in:
Jian Chen 2024-12-10 10:50:08 -08:00
Родитель 9b7735f08c 5f7b9d0245
Коммит 2c077e194f
138 изменённых файлов: 1635 добавлений и 1053 удалений

Просмотреть файл

@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/rotary_embedding.h
${MLAS_SRC_DIR}/rotary_embedding.cpp
)
target_sources(onnxruntime_mlas PRIVATE
@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set(mlas_platform_preprocess_srcs
@ -367,6 +372,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@ -384,8 +391,9 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@ -395,8 +403,9 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

Просмотреть файл

@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>(Optional) Hardware architecture.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>max_size</tt> : int</dt>
<dd>max size in the context. Usage depend on the EP.</dd>
<dt><tt>notes</tt> : string</dt>
<dd>(Optional) Some notes for the model</dd>
<dt><tt>onnx_model_filename</tt> : string</dt>

Просмотреть файл

@ -8,6 +8,9 @@
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace logging {
class Logger;
}
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
@ -33,6 +36,7 @@ class KernelRegistry {
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger,
const KernelCreateInfo** out) const;
// map of type constraint name to required type
@ -42,6 +46,7 @@ class KernelRegistry {
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;
/**
@ -61,13 +66,15 @@ class KernelRegistry {
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver) {
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger) {
const KernelCreateInfo* info;
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
return st.IsOK();
}
@ -83,6 +90,7 @@ class KernelRegistry {
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;
// Check whether the types of inputs/outputs of the given node match the extra

Просмотреть файл

@ -53,6 +53,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
@ -84,6 +85,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);

Просмотреть файл

@ -47,8 +47,20 @@ enum COREMLFlags {
// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES
static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property
// Core ML segments the models compute graph and specializes each segment for the target compute device.
// This process can affect the model loading time and the prediction latency.
// Use this option to tailor the specialization strategy for your model.
static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy";
// Profile the Core ML MLComputePlan.
// This logs the hardware each operator is dispatched to and the estimated execution time.
// Intended for developer usage but provide useful diagnostic information if performance is not as expected.
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
#ifdef __cplusplus
extern "C" {

Просмотреть файл

@ -3667,6 +3667,9 @@ struct OrtApi {
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
* "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
* - "0": Default. Disabled.
* - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
@ -4612,6 +4615,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.17.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
_In_ OrtSessionOptions* options,
@ -4629,6 +4634,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.18.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options,
@ -4642,7 +4649,10 @@ struct OrtApi {
* \param[in] mem_info OrtMemoryInfo instance
* \param[in] count_or_bytes How many bytes is this scratch buffer
* \param[out] out A pointer to the scrach buffer
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.18.
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
@ -4653,6 +4663,8 @@ struct OrtApi {
* \param[out] out A pointer to OrtAllocator
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.18.
*/
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
@ -4674,6 +4686,8 @@ struct OrtApi {
* \param[in] num_external_initializer_files Number of external files
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.18.
*/
ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options,
_In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names,
@ -4696,6 +4710,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@ -4714,6 +4730,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@ -4735,6 +4753,8 @@ struct OrtApi {
* \param[in] adapter OrtLoraAdapter instance
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
@ -4753,6 +4773,8 @@ struct OrtApi {
* \param[in] kv_len Number of elements in the keys and values arrays
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);

Просмотреть файл

@ -198,19 +198,6 @@ module.exports = {
'_OrtReleaseTensor',
'_OrtRun',
'_OrtRunWithBinding',
'_OrtTrainingCopyParametersFromBuffer',
'_OrtTrainingCopyParametersToBuffer',
'_OrtTrainingCreateSession',
'_OrtTrainingEvalStep',
'_OrtTrainingGetModelInputOutputCount',
'_OrtTrainingGetModelInputOutputName',
'_OrtTrainingGetParametersSize',
'_OrtTrainingLazyResetGrad',
'_OrtTrainingLoadCheckpoint',
'_OrtTrainingOptimizerStep',
'_OrtTrainingReleaseCheckpoint',
'_OrtTrainingReleaseSession',
'_OrtTrainingRunTrainStep',
],
},
],

Просмотреть файл

@ -3,7 +3,6 @@
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { TrainingSession } from './training-session.js';
/**
* @ignore
@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
): Promise<SessionHandler.ReturnType>;
}
/**
* Represent a handler instance of a training inference session.
*
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
readonly evalInputNames: readonly string[];
readonly evalOutputNames: readonly string[];
lazyResetGrad(): Promise<void>;
runTrainStep(
feeds: SessionHandler.FeedsType,
fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions,
): Promise<SessionHandler.ReturnType>;
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
runEvalStep(
feeds: SessionHandler.FeedsType,
fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions,
): Promise<SessionHandler.ReturnType>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}
/**
* Represent a backend that provides implementation of model inferencing.
*
@ -84,14 +56,6 @@ export interface Backend {
uriOrBuffer: string | Uint8Array,
options?: InferenceSession.SessionOptions,
): Promise<InferenceSessionHandler>;
createTrainingSessionHandler?(
checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
options: InferenceSession.SessionOptions,
): Promise<TrainingSessionHandler>;
}
export { registerBackend } from './backend-impl.js';

Просмотреть файл

@ -2,6 +2,7 @@
// Licensed under the MIT License.
import { env as envImpl } from './env-impl.js';
import { TryGetGlobalType } from './type-helper.js';
export declare namespace Env {
export type WasmPathPrefix = string;
@ -14,7 +15,6 @@ export declare namespace Env {
* If not modified, the filename of the .wasm file is:
* - `ort-wasm-simd-threaded.wasm` for default build
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
* - `ort-training-wasm-simd-threaded.wasm` for training build
*/
wasm?: URL | string;
/**
@ -25,7 +25,6 @@ export declare namespace Env {
* If not modified, the filename of the .mjs file is:
* - `ort-wasm-simd-threaded.mjs` for default build
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
* - `ort-training-wasm-simd-threaded.mjs` for training build
*/
mjs?: URL | string;
}
@ -200,22 +199,16 @@ export declare namespace Env {
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
* Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
*
* see comments on {@link Tensor.GpuBufferType}
*/
adapter: unknown;
adapter: TryGetGlobalType<'GPUAdapter'>;
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
*
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
* Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
*
* see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
*/
readonly device: unknown;
readonly device: TryGetGlobalType<'GPUDevice'>;
/**
* Set or get whether validate input content.
*

Просмотреть файл

@ -26,4 +26,3 @@ export * from './tensor-factory.js';
export * from './trace.js';
export * from './onnx-model.js';
export * from './onnx-value.js';
export * from './training-session.js';

Просмотреть файл

@ -4,6 +4,7 @@
import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js';
import { OnnxModelOptions } from './onnx-model.js';
import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js';
import { TryGetGlobalType } from './type-helper.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@ -282,7 +283,7 @@ export declare namespace InferenceSession {
extends WebNNExecutionProviderName,
Omit<WebNNContextOptions, 'deviceType'>,
Required<Pick<WebNNContextOptions, 'deviceType'>> {
context: unknown /* MLContext */;
context: TryGetGlobalType<'MLContext'>;
}
/**
@ -291,8 +292,8 @@ export declare namespace InferenceSession {
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice
*/
export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName {
context: unknown /* MLContext */;
gpuDevice: unknown /* GPUDevice */;
context: TryGetGlobalType<'MLContext'>;
gpuDevice: TryGetGlobalType<'GPUDevice'>;
}
/**

Просмотреть файл

@ -4,6 +4,7 @@
import { TensorFactory } from './tensor-factory.js';
import { Tensor as TensorImpl } from './tensor-impl.js';
import { TypedTensorUtils } from './tensor-utils.js';
import { TryGetGlobalType } from './type-helper.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@ -131,24 +132,19 @@ export declare namespace Tensor {
*/
export type TextureDataTypes = 'float32';
type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
/**
* type alias for WebGPU buffer
*
* The reason why we don't use type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is because "@webgpu/types"
* requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its version need to be chosen
* carefully according to the TypeScript version being used. This means so far there is not a way to keep every
* TypeScript version happy. It turns out that we will easily broke users on some TypeScript version.
*
* for more info see https://github.com/gpuweb/types/issues/127
*/
export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>;
type MLTensorTypeFallback = { destroy(): void };
/**
* type alias for WebNN MLTensor
*
* The specification for WebNN's MLTensor is currently in flux.
*/
export type MLTensorType = unknown;
export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>;
/**
* supported data types for constructing a tensor from a WebGPU buffer

Просмотреть файл

@ -1,273 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { resolveBackendAndExecutionProviders } from './backend-impl.js';
import { SessionHandler, TrainingSessionHandler } from './backend.js';
import { InferenceSession as InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { Tensor } from './tensor.js';
import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js';
type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
type FetchesType = InferenceSession.FetchesType;
type ReturnType = InferenceSession.ReturnType;
type RunOptions = InferenceSession.RunOptions;
const noBackendErrMsg: string =
'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files.";
export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
this.hasOptimizerModel = hasOptimizerModel;
this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
private hasOptimizerModel: boolean;
private hasEvalModel: boolean;
get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}
get evalInputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalInputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
get evalOutputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalOutputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
static async create(
trainingOptions: TrainingSessionCreateOptions,
sessionOptions?: SessionOptions,
): Promise<TrainingSession> {
const evalModel: string | Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};
// resolve backend, update session options with validated EPs, and create session handler
const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState,
trainingOptions.trainModel,
evalModel,
optimizerModel,
optionsWithValidatedEPs,
);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
}
}
/**
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
* the given parameters to SessionHandler.FetchesType and RunOptions.
*
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
* names.
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
* names.
* @param feeds the required input
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
* @param arg2 optional RunOptions object.
* @returns
*/
typeNarrowingForRunStep(
inputNames: readonly string[],
outputNames: readonly string[],
feeds: FeedsType,
arg1?: FetchesType | RunOptions,
arg2?: RunOptions,
): [SessionHandler.FetchesType, RunOptions] {
const fetches: { [name: string]: OnnxValue | null } = {};
let options: RunOptions = {};
// check inputs
if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
throw new TypeError(
"'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.",
);
}
let isFetchesEmpty = true;
// determine which override is being used
if (typeof arg1 === 'object') {
if (arg1 === null) {
throw new TypeError('Unexpected argument[1]: cannot be null.');
}
if (arg1 instanceof Tensor) {
throw new TypeError("'fetches' cannot be a Tensor");
}
if (Array.isArray(arg1)) {
if (arg1.length === 0) {
throw new TypeError("'fetches' cannot be an empty array.");
}
isFetchesEmpty = false;
// output names
for (const name of arg1) {
if (typeof name !== 'string') {
throw new TypeError("'fetches' must be a string array or an object.");
}
if (outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
}
if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError("'options' must be an object.");
}
} else {
// decide whether arg1 is fetches or options
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
isFetches = true;
isFetchesEmpty = false;
fetches[name] = v;
}
}
}
if (isFetches) {
if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError("'options' must be an object.");
}
} else {
options = arg1 as RunOptions;
}
}
} else if (typeof arg1 !== 'undefined') {
throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
}
// check if all inputs are in feed
for (const name of inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
}
// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of outputNames) {
fetches[name] = null;
}
}
return [fetches, options];
}
/**
* Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler
* and changes it into a map of Tensors.
*
* @param results
* @returns
*/
convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
const returnValue: { [name: string]: OnnxValue } = {};
for (const key in results) {
if (Object.hasOwnProperty.call(results, key)) {
const result = results[key];
if (result instanceof Tensor) {
returnValue[key] = result;
} else {
returnValue[key] = new Tensor(result.type, result.data, result.dims);
}
}
}
return returnValue;
}
async lazyResetGrad(): Promise<void> {
await this.handler.lazyResetGrad();
}
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const [fetches, options] = this.typeNarrowingForRunStep(
this.trainingInputNames,
this.trainingOutputNames,
feeds,
arg1,
arg2,
);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
}
async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise<void> {
if (this.hasOptimizerModel) {
await this.handler.runOptimizerStep(options || {});
} else {
throw new Error('This TrainingSession has no OptimizerModel loaded.');
}
}
runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise<ReturnType>;
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise<ReturnType>;
async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
if (this.hasEvalModel) {
const [fetches, options] = this.typeNarrowingForRunStep(
this.evalInputNames,
this.evalOutputNames,
feeds,
arg1,
arg2,
);
const results = await this.handler.runEvalStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
} else {
throw new Error('This TrainingSession has no EvalModel loaded.');
}
}
async getParametersSize(trainableOnly = true): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}
async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise<void> {
const paramsSize = await this.getParametersSize(trainableOnly);
// checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
// of parameters
if (array.length !== 4 * paramsSize) {
throw new Error(
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
'the model. Please use getParametersSize method to check.',
);
}
return this.handler.loadParametersBuffer(array, trainableOnly);
}
async getContiguousParameters(trainableOnly = true): Promise<OnnxValue> {
return this.handler.getContiguousParameters(trainableOnly);
}
async release(): Promise<void> {
return this.handler.dispose();
}
}

Просмотреть файл

@ -1,206 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
/* eslint-disable @typescript-eslint/no-redeclare */
export declare namespace TrainingSession {
/**
* Either URI file path (string) or Uint8Array containing model or checkpoint information.
*/
type UriOrBuffer = string | Uint8Array;
}
/**
* Represent a runtime instance of an ONNX training session,
* which contains a model that can be trained, and, optionally,
* an eval and optimizer model.
*/
export interface TrainingSession {
// #region run()
/**
* Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
* runOptimizerStep.
*/
lazyResetGrad(): Promise<void>;
/**
* Run TrainStep asynchronously with the given feeds and options.
*
* @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
detail.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
*/
runTrainStep(
feeds: InferenceSession.FeedsType,
options?: InferenceSession.RunOptions,
): Promise<InferenceSession.ReturnType>;
/**
* Run a single train step with the given inputs and options.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runTrainStep(
feeds: InferenceSession.FeedsType,
fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions,
): Promise<InferenceSession.ReturnType>;
/**
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
*
* @param options - Optional. A set of options that controls the behavior of model optimizing.
*/
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType,
options?: InferenceSession.RunOptions,
): Promise<InferenceSession.ReturnType>;
/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType,
fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions,
): Promise<InferenceSession.ReturnType>;
// #endregion
// #region copy parameters
/**
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
* the parameters) elements of all the parameters in the training state.
*
* @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;
/**
* Copies parameter values from the given buffer to the training state. Currently, only supporting models with
* parameters of type Float32.
*
* @param buffer - A Uint8Array representation of Float32 parameters.
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
*/
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
/**
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
* Currently, only supporting models with parameters of type Float32.
*
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
* for which requires_grad is set to true. Default value is true.
* @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion
// #region release()
/**
* Release the inference session and the underlying resources.
*/
release(): Promise<void>;
// #endregion
// #region metadata
/**
* Get input names of the loaded training model.
*/
readonly trainingInputNames: readonly string[];
/**
* Get output names of the loaded training model.
*/
readonly trainingOutputNames: readonly string[];
/**
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly evalInputNames: readonly string[];
/**
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly evalOutputNames: readonly string[];
// #endregion
}
/**
* Represents the optional parameters that can be passed into the TrainingSessionFactory.
*/
export interface TrainingSessionCreateOptions {
/**
* URI or buffer for a .ckpt file that contains the checkpoint for the training model.
*/
checkpointState: TrainingSession.UriOrBuffer;
/**
* URI or buffer for the .onnx training file.
*/
trainModel: TrainingSession.UriOrBuffer;
/**
* Optional. URI or buffer for the .onnx optimizer model file.
*/
optimizerModel?: TrainingSession.UriOrBuffer;
/**
* Optional. URI or buffer for the .onnx eval model file.
*/
evalModel?: TrainingSession.UriOrBuffer;
}
/**
* Defines method overload possibilities for creating a TrainingSession.
*/
export interface TrainingSessionFactory {
// #region create()
/**
* Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
*
* @param trainingOptions specify models and checkpoints to load into the Training Session
* @param sessionOptions specify configuration for training session behavior
*
* @returns Promise that resolves to a TrainingSession object
*/
create(
trainingOptions: TrainingSessionCreateOptions,
sessionOptions?: InferenceSession.SessionOptions,
): Promise<TrainingSession>;
// #endregion
}
// eslint-disable-next-line @typescript-eslint/naming-convention
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;

Просмотреть файл

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
* A helper type to get certain types if they are declared in global scope.
*
* For example, if you installed "@webgpu/types" as a dev dependency, then `TryGetTypeIfDeclared<'GPUDevice'>` will
* be type `GPUDevice`, otherwise it will be type `unknown`.
*
*
* We don't want to introduce "@webgpu/types" as a dependency of this package because:
*
* (1) For JavaScript users, it's not needed. For TypeScript users, they can install it as dev dependency themselves.
*
* (2) because "@webgpu/types" requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its
* version need to be chosen carefully according to the TypeScript version being used. This means so far there is not a
* way to keep every TypeScript version happy. It turns out that we will easily broke users on some TypeScript version.
*
* for more info see https://github.com/gpuweb/types/issues/127
*
* Update (2024-08-07): The reason (2) may be no longer valid. Most people should be using TypeScript >= 5.1 by now.
* However, we are still not sure whether introducing "@webgpu/types" as direct dependency is a good idea. We find this
* type helper is useful for TypeScript users.
*
* @ignore
*/
export type TryGetGlobalType<Name extends string, Fallback = unknown> = typeof globalThis extends {
[k in Name]: { prototype: infer T };
}
? T
: Fallback;

Просмотреть файл

@ -1,6 +1,7 @@
{
"entryPoints": ["lib/index.ts"],
"excludeInternal": true,
"intentionallyNotExported": ["TryGetGlobalType"],
"name": "ONNX Runtime JavaScript API",
"readme": "none",
"cleanOutputDir": true

16
js/node/package-lock.json сгенерированный
Просмотреть файл

@ -276,12 +276,12 @@
"dev": true
},
"node_modules/axios": {
"version": "1.6.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"version": "1.7.9",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
"integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
"dev": true,
"dependencies": {
"follow-redirects": "^1.15.0",
"follow-redirects": "^1.15.6",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
@ -1581,12 +1581,12 @@
"dev": true
},
"axios": {
"version": "1.6.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"version": "1.7.9",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
"integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
"dev": true,
"requires": {
"follow-redirects": "^1.15.0",
"follow-redirects": "^1.15.6",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}

Двоичные данные
js/react_native/android/gradle/wrapper/gradle-wrapper.jar поставляемый

Двоичный файл не отображается.

Просмотреть файл

@ -2,5 +2,7 @@ distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d
distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

35
js/react_native/android/gradlew поставляемый
Просмотреть файл

@ -55,7 +55,7 @@
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
@ -80,13 +80,11 @@ do
esac
done
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
APP_NAME="Gradle"
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
@ -133,22 +131,29 @@ location of your Java installation."
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
if ! command -v java >/dev/null 2>&1
then
die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC2039,SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC2039,SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
@ -193,11 +198,15 @@ if "$cygwin" || "$msys" ; then
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Collect all arguments for the java command:
# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
# and any embedded shellness will be escaped.
# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
# treated as '${Hostname}' itself on the command line.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \

21
js/react_native/android/gradlew.bat поставляемый
Просмотреть файл

@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@ -42,11 +43,11 @@ set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
echo. 1>&2
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2
echo. 1>&2
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
echo location of your Java installation. 1>&2
goto fail
@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
echo. 1>&2
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2
echo. 1>&2
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
echo location of your Java installation. 1>&2
goto fail

Просмотреть файл

@ -487,7 +487,7 @@ export const prepareInputOutputTensor = (
}
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const gpuBuffer = tensor[2].gpuBuffer;
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
const registerBuffer = wasm.jsepRegisterBuffer;

12
js/web/package-lock.json сгенерированный
Просмотреть файл

@ -861,9 +861,9 @@
}
},
"node_modules/cross-spawn": {
"version": "6.0.5",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
"integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
"version": "6.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
"integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
"dev": true,
"dependencies": {
"nice-try": "^1.0.4",
@ -4312,9 +4312,9 @@
}
},
"cross-spawn": {
"version": "6.0.5",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
"integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
"version": "6.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
"integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
"dev": true,
"requires": {
"nice-try": "^1.0.4",

Просмотреть файл

@ -4,6 +4,7 @@
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
using onnxruntime::concurrency::ThreadPool;
@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
bool sign = false;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = i & 1;
j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
float input_data_j = static_cast<float>(input_data[j]);
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
if (sign) {
output_data_i += input_data_j * sin_data_cache_idx;
} else {
output_data_i -= input_data_j * sin_data_cache_idx;
}
output_data[i] = static_cast<T>(output_data_i);
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
if (rotary_emb_dim < head_size) {
std::memcpy(output_data + rotary_emb_dim,
input_data + rotary_emb_dim,
(head_size - rotary_emb_dim) * sizeof(T));
}
}
});

Просмотреть файл

@ -31,6 +31,7 @@ Subgraph::Subgraph(
allocator_(nullptr),
is_output_float16_(false) {
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
used_implicit_inputs = std::vector<bool>(num_implicit_inputs, true);
auto& subgraph_inputs = subgraph.GetInputs();
auto& subgraph_outputs = subgraph.GetOutputs();
@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state,
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
for (auto& entry : node.ImplicitInputDefs()) {
feed_names.push_back(entry->Name());
const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
const auto& implicit_input_defs = node.ImplicitInputDefs();
for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) {
const auto* entry = implicit_input_defs[i];
int idx;
if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) {
feed_names.push_back(entry->Name());
} else {
--num_implicit_inputs;
used_implicit_inputs[i] = false;
}
}
InlinedVector<OrtDevice> feed_locations;

Просмотреть файл

@ -31,6 +31,7 @@ class Subgraph {
const GraphViewer& subgraph; // The subgraph
int num_implicit_inputs;
std::vector<bool> used_implicit_inputs;
int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience.
int num_subgraph_outputs; // Same as subgraph_output_names.size()

Просмотреть файл

@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
}
// Pass through implicit inputs.
for (const auto* entry : implicit_inputs) {
decoder_feeds.push_back(*entry);
for (size_t i = 0; i < implicit_inputs.size(); ++i) {
const auto* entry = implicit_inputs[i];
if (used_implicit_inputs[i]) {
decoder_feeds.push_back(*entry);
}
}
return Status::OK();

Просмотреть файл

@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
pinned_allocator,
location));
for (const auto* entry : implicit_inputs) {
feeds.push_back(*entry);
for (size_t i = 0; i < implicit_inputs.size(); ++i) {
const auto* entry = implicit_inputs[i];
if (used_implicit_inputs[i]) {
feeds.push_back(*entry);
}
}
return Status::OK();

Просмотреть файл

@ -138,7 +138,8 @@ class PlannerImpl {
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan,
const logging::Logger& logger)
: context_(&context),
plan_(plan),
parent_node_(parent_node),
@ -148,14 +149,15 @@ class PlannerImpl {
kernel_create_info_map_(kernel_create_info_map),
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
ort_value_name_idx_map_(ort_value_name_idx_map) {}
ort_value_name_idx_map_(ort_value_name_idx_map),
logger_(logger) {
}
Status CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
const PathString& partition_config_file,
const logging::Logger& logger);
const PathString& partition_config_file);
private:
gsl::not_null<const ISequentialPlannerContext*> context_;
@ -183,6 +185,12 @@ class PlannerImpl {
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;
// logger_ is not currently used in a minimal build
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
[[maybe_unused]]
#endif
const logging::Logger& logger_;
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
struct OrtValueInfo {
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
@ -213,6 +221,7 @@ class PlannerImpl {
FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point)
: ml_value(ort_value), deallocate_point(dealloc_point) {}
};
// freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when
// they became free (more recently freed earlier in the list).
std::list<FreeBufferInfo> freelist_;
@ -225,7 +234,8 @@ class PlannerImpl {
}
int& UseCount(OrtValueIndex n) {
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
"invalid value index: ", n, " against size ", ort_value_info_.size());
return ort_value_info_[n].usecount;
}
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
@ -335,9 +345,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
@ -361,9 +371,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
@ -397,8 +407,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
#endif
}
}
@ -448,8 +458,8 @@ class PlannerImpl {
return true;
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
<< producer_node->Name() << " as it has external outputs.";
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
<< producer_node->Name() << " as it has external outputs.";
#endif
}
}
@ -1198,9 +1208,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
@ -1241,9 +1251,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
@ -1290,8 +1300,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
#endif
}
}
@ -1869,8 +1879,7 @@ class PlannerImpl {
}
#ifndef ORT_ENABLE_STREAM
void PartitionIntoStreams(const logging::Logger& /*logger*/,
const ExecutionProviders& /*execution_providers*/,
void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/,
const PathString& /*partition_config_file*/) {
if (graph_viewer_.NumberOfNodes() > 0) {
stream_nodes_.push_back({});
@ -1915,11 +1924,11 @@ class PlannerImpl {
#else
void
PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
void PartitionIntoStreams(const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_,
context_->GetExecutionOrder());
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
plan_.node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
const PathString& partition_config_file,
const logging::Logger& logger) {
const PathString& partition_config_file) {
// 1. partition graph into streams
PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file);
PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file);
// 2. initialize the plan based on stream partition result
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan(
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map, context, *plan);
ort_value_name_idx_map, context, *plan, logger);
return planner.CreatePlan(
#ifdef ORT_ENABLE_STREAM
stream_handle_registry,
#endif
partition_config_file,
logger);
partition_config_file);
}
#ifdef ORT_ENABLE_STREAM

Просмотреть файл

@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes) {
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger) {
// automatic conversion from const std::vector&
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());
@ -83,7 +84,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {
candidates.push(consumer_node->Index());
LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
}
}
return Status::OK();
@ -159,9 +160,9 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
if (place_in_cpu) {
cpu_nodes.insert(cur);
LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
<< " capable of executing this node";
LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution "
"on other EPs capable of executing this node";
for (auto* output : node->OutputDefs()) {
cpu_output_args.insert(output);
}

Просмотреть файл

@ -9,6 +9,9 @@
#include "core/graph/graph_viewer.h"
namespace onnxruntime {
namespace logging {
class Logger;
}
/**
Returns a list of nodes that are preferred on CPU.
@ -19,6 +22,7 @@ namespace onnxruntime {
*/
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger);
} // namespace onnxruntime

Просмотреть файл

@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep,
};
} // namespace
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) {
auto& current_ep = params.current_ep.get();
const auto& ep_type = current_ep.Type();
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) {
LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
"the layout transformer.";
return Status::OK();
}
@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
kernel_registries_for_ep,
kernel_registry_mgr.GetKernelTypeStrResolver()};
kernel_registry_mgr.GetKernelTypeStrResolver(),
logger};
auto& graph = params.graph.get();
auto& capabilities = params.capabilities.get();
@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
const KernelRegistryManager& kernel_registry_mgr,
const IExecutionProvider& current_ep,
const logging::Logger& logger,
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
const auto& ep_type = current_ep.Type();
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
kernel_registries_for_ep,
kernel_registry_mgr.GetKernelTypeStrResolver()};
kernel_registry_mgr.GetKernelTypeStrResolver(),
logger};
// TODO: Provide EP with a capability to look inside the functions.
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
GraphPartitioner::Mode mode,
int& fused_node_unique_id,
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
const layout_transformation::DebugGraphFn& debug_graph_fn) {
const layout_transformation::DebugGraphFn& debug_graph_fn,
const logging::Logger& logger) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
transform_layout_fn, debug_graph_fn));
transform_layout_fn, debug_graph_fn, logger));
}
}
@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::cref(transform_layout_fn),
std::cref(debug_graph_fn)};
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
return Status::OK();
}
@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
if (n != nullptr) {
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
nodes_to_compile.push_back(n);
capabilities_to_compile.push_back(std::move(capability));
} else {
@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_mgr,
Graph& graph,
const logging::Logger& logger,
InlinedHashSet<std::string>& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_mgr,
*subgraph,
logger,
not_inlined,
inlined_count));
}
@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
InlinedHashSet<NodeIndex> claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger,
capabilities));
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
if (nodes.size() == 1) {
@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
const ExecutionProviders& execution_providers,
KernelRegistryManager& kernel_registry_manager) {
KernelRegistryManager& kernel_registry_manager,
const logging::Logger& logger) {
bool modified_graph = false;
auto& graph = partition_params.graph.get();
@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function,
partition_params.debug_graph_fn));
partition_params.debug_graph_fn,
logger));
}
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params,
KernelRegistryManager& kernel_registry_mgr,
IExecutionProvider& current_ep) {
IExecutionProvider& current_ep,
const logging::Logger& logger) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
auto& graph = partition_params.graph.get();
@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
auto& subgraph = *entry.second;
PartitionParams subgraph_partition_params = partition_params;
subgraph_partition_params.graph = std::ref(subgraph);
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep));
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep,
logger));
}
}
@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
};
// clang-format on
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
return Status::OK();
}
@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
// Simplified partitioning where custom EPs may produce compiled nodes.
static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
const ExecutionProviders& execution_providers,
KernelRegistryManager& kernel_registry_manager) {
KernelRegistryManager& kernel_registry_manager,
const logging::Logger& logger) {
// process full graph with each EP
for (const auto& ep : execution_providers) {
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep));
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger));
}
return Status::OK();
@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_manager,
graph,
logger,
not_inlined,
inlined_count));
@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
providers_, kernel_registry_mgr_));
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
#endif //! defined(ORT_MINIMAL_BUILD)
} else {
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params,
providers_, kernel_registry_mgr_));
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger));
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

Просмотреть файл

@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
public:
KernelLookup(ProviderType provider_type,
gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries,
const IKernelTypeStrResolver& kernel_type_str_resolver)
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger)
: provider_type_{provider_type},
kernel_registries_{kernel_registries},
kernel_type_str_resolver_{kernel_type_str_resolver} {
kernel_type_str_resolver_{kernel_type_str_resolver},
logger_{logger} {
ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified.");
}
const KernelCreateInfo* LookUpKernel(const Node& node) const override {
const KernelCreateInfo* kernel_create_info{};
for (const auto& registry : kernel_registries_) {
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_,
const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_,
&kernel_create_info);
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
return kernel_create_info;
@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
ProviderType provider_type_;
const gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries_;
const IKernelTypeStrResolver& kernel_type_str_resolver_;
const logging::Logger& logger_;
};
} // namespace onnxruntime

Просмотреть файл

@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const {
const auto& node_provider = node.GetExecutionProviderType();
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
std::ostream_iterator<std::string>(oss, "\n"));
oss << ")";
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
}
@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger,
const KernelCreateInfo** out) const {
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out);
return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out);
}
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const {
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out);
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out);
}
static bool KernelDefCompatible(int version, const KernelDef& kernel_def,
@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const {
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider));
if (out) *out = nullptr;
@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
std::ostream_iterator<std::string>(oss, "\n"));
oss << ")";
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
}

Просмотреть файл

@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistr
}
#endif
Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
Status KernelRegistryManager::SearchKernelRegistry(const Node& node, const logging::Logger& logger,
/*out*/ const KernelCreateInfo** kernel_create_info) const {
Status status;
@ -82,7 +82,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
}
for (auto& registry : custom_kernel_registries_) {
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
if (status.IsOK()) {
return status;
}
@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
}
if (p != nullptr) {
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
if (status.IsOK()) {
return status;
}
@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
}
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r,
const Node& node,
const std::string& provider_type,
const logging::Logger& logger) {
const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver());
return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(),
logger);
});
}

Просмотреть файл

@ -67,13 +67,14 @@ class KernelRegistryManager {
// This function assumes the node is already assigned to an execution provider
// Don't call this function before graph partition is done
Status SearchKernelRegistry(const Node& node,
Status SearchKernelRegistry(const Node& node, const logging::Logger& logger,
/*out*/ const KernelCreateInfo** kernel_create_info) const;
/**
* Whether this node can be run on this provider
*/
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type,
const logging::Logger& logger);
Status CreateKernel(const Node& node,
const IExecutionProvider& execution_provider,

Просмотреть файл

@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
bool saving_ort_format) {
for (auto& node : graph_.Nodes()) {
const KernelCreateInfo* kci = nullptr;
auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
if (!status.IsOK() && saving_ort_format) {
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
// in that case we assigned the node to that EP but do not compile it into a fused node.
@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
// at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
// if that's not possible for some reason we can fallback to the CPU EP implementation.
node.SetExecutionProviderType(kCpuExecutionProvider);
status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
}
ORT_RETURN_IF_ERROR(status);

Просмотреть файл

@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
.Attr(
"max_size",
"max size in the context. Usage depend on the EP.",
AttributeProto::INT,
static_cast<int64_t>(0))
.AllowUncheckedAttributes()
.Input(
0,

Просмотреть файл

@ -9,7 +9,7 @@
#include "core/graph/constants.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/shape_inference_functions.h"
#include "onnx/onnx-ml.pb.h" // ?
#include "core/graph/onnx_protobuf.h"
// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx);
void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape,
int input1Idx, int input2Idx);
namespace defs::math::utils {
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
}
} // namespace ONNX_NAMESPACE
@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
}
}
if (all_lengths_known) {
output_shape->mutable_dim(axis)->set_dim_value(total_length);
}
}));
if (all_lengths_known) {
output_shape->mutable_dim(axis)->set_dim_value(total_length);
}
}));
ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema()
.SetDoc("Return elements, either from X or Y, depending on condition.")
@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::INT, static_cast<int64_t>(0))
.Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT, OPTIONAL_VALUE)
.Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
.Attr("past_present_share_buffer",
"Corresponding past and present are same tensor, its shape is "
"(2, batch_size, num_heads, max_sequence_length, head_size)",
AttributeProto::INT, OPTIONAL_VALUE)
.Attr("mask_filter_value",

Просмотреть файл

@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
size_t Count
);
/**
* @brief rotary embedding for one hidden state vector
*
* @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
* @param input: input tensor, of shape [dim]
* @param sin: sin tensor, of shape [dim/2]
* @param cos: cos tensor, of shape [dim/2]
* @param dim: dimension of rotary embedding
* @param interleaved: whether the real part and imaginary parts are interleaved
* @param output: output tensor, of shape [dim]
*/
template <typename T>
void
MLASCALL
MlasRotaryEmbedOneRow(
const T* input,
const T* sin,
const T* cos,
size_t dim,
bool interleaved,
T* output
);
/**
* @brief Whether current CPU supports FP16 acceleration.
*/

Просмотреть файл

@ -6,7 +6,7 @@ Licensed under the MIT License.
Module Name:
fp16_neon_common.cpp
cast_kernel_neon.cpp
Abstract:

Просмотреть файл

@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
//
// Rotary embedding dispatch structure.
//
struct MLAS_ROPE_DISPATCH;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
//
// Quantized depthwise convolution kernels.
//
@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
};
inline

Просмотреть файл

@ -543,6 +543,7 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->RopeDispatch = &MlasRopeDispatchNeon;
//
// Check if the processor supports ASIMD dot product instructions.

Просмотреть файл

@ -0,0 +1,101 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding.cpp
Abstract:
This module implements rotary embedding kernels for fp32/16.
--*/
#include "rotary_embedding.h"
namespace {
template <typename T>
void
MLASCALL
MlasRotaryEmbedOneRow_FallBack(
const T* input_data,
const T* sin_data,
const T* cos_data,
size_t rotary_emb_dim,
bool interleaved,
T* output_data
) {
const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
size_t cache_idx = 0;
bool sign = false;
size_t j = 0;
for (size_t i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = i & 1;
j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
float input_data_j = static_cast<float>(input_data[j]);
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
if (sign) {
output_data_i += input_data_j * sin_data_cache_idx;
} else {
output_data_i -= input_data_j * sin_data_cache_idx;
}
output_data[i] = static_cast<T>(output_data_i);
}
}
} // namespace
template <>
void
MLASCALL
MlasRotaryEmbedOneRow<float>(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
) {
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->SRope == nullptr) {
MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
return;
}
dispatch->SRope(input, sin, cos, dim, interleaved, output);
}
template <>
void
MLASCALL
MlasRotaryEmbedOneRow<MLAS_FP16>(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
) {
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->HRope == nullptr) {
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
return;
}
dispatch->HRope(input, sin, cos, dim, interleaved, output);
}

Просмотреть файл

@ -0,0 +1,46 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding.h
Abstract:
This module includes kernel function prototypes and helper functions for
implementing rotary embedding.
--*/
#pragma once
#include "mlasi.h"
struct MLAS_ROPE_DISPATCH {
// rotary embedding kernel for fp32
typedef void(SRope_Fn)(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
);
SRope_Fn* SRope = nullptr;
// rotary embedding kernel for fp16
typedef void(HRope_Fn)(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
);
HRope_Fn* HRope = nullptr;
};

Просмотреть файл

@ -0,0 +1,32 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon.cpp
Abstract:
This module implements the rotary embedding kernels for ARM NEON.
--*/
#include "rotary_embedding.h"
#include "rotary_embedding_kernel_neon.h"
//
// Kernel dispatch structure definition.
//
const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
MLAS_ROPE_DISPATCH d;
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
if (MlasFp16AccelerationSupported()) {
d.HRope = rope_neon::RopeKernel_Fp16;
}
#endif
return d;
}();

Просмотреть файл

@ -0,0 +1,37 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon.h
Abstract:
This module includes function declarations and common helper functions for
rotary embedding on ARM cpu.
--*/
#pragma once
#include <arm_neon.h>
#include "mlasi.h"
namespace rope_neon {
// Rotary embedding kernel for fp16. Embed one hidden state vector.
void
RopeKernel_Fp16(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
);
} // namespace rope_neon

Просмотреть файл

@ -0,0 +1,253 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon_fp16.cpp
Abstract:
This module implements the fp16 rotary embedding kernels for ARM NEON.
--*/
#include <arm_neon.h>
#include <cassert>
#include "fp16_common.h"
#include "rotary_embedding.h"
#include "rotary_embedding_kernel_neon.h"
namespace rope_neon {
namespace {
template <bool interleaved>
void
RopeKernel_Fp16_Impl(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
);
template <>
void
RopeKernel_Fp16_Impl<false>(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
) {
const size_t half_dim = dim >> 1;
size_t i = 0, j = half_dim;
for (; i + 7 < half_dim; i += 8, j += 8) {
float16x8_t real = MlasLoadFloat16x8(input + i);
float16x8_t imag = MlasLoadFloat16x8(input + j);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
MlasStoreFloat16x8(output + i, real_out);
MlasStoreFloat16x8(output + j, imag_out);
}
for (; i + 3 < half_dim; i += 4, j += 4) {
float16x4_t real = MlasLoadFloat16x4(input + i);
float16x4_t imag = MlasLoadFloat16x4(input + j);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreFloat16x4(output + i, real_out);
MlasStoreFloat16x4(output + j, imag_out);
}
if (half_dim - i == 3) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
} else if (half_dim - i == 2) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
} else if (half_dim - i == 1) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
}
}
template <>
void
RopeKernel_Fp16_Impl<true>(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
) {
size_t i = 0;
for (; i + 15 < dim; i += 16) {
float16x8_t x0 = MlasLoadFloat16x8(input + i);
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
float16x8_t real = vuzp1q_f16(x0, x1);
float16x8_t imag = vuzp2q_f16(x0, x1);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
float16x8_t y0 = vzip1q_f16(real_out, imag_out);
float16x8_t y1 = vzip2q_f16(real_out, imag_out);
MlasStoreFloat16x8(output + i, y0);
MlasStoreFloat16x8(output + i + 8, y1);
}
for (; i + 7 < dim; i += 8) {
float16x4_t x0 = MlasLoadFloat16x4(input + i);
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
float16x4_t real = vuzp1_f16(x0, x1);
float16x4_t imag = vuzp2_f16(x0, x1);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
float16x4_t y0 = vzip1_f16(real_out, imag_out);
float16x4_t y1 = vzip2_f16(real_out, imag_out);
MlasStoreFloat16x4(output + i, y0);
MlasStoreFloat16x4(output + i + 4, y1);
}
if (dim - i == 6) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
} else if (dim - i == 4) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
} else if (dim - i == 2) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
}
}
} // namespace
void
RopeKernel_Fp16(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
) {
// real part and imaginary part must be paired
assert(dim % 2 == 0);
const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
if (interleaved) {
RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
} else {
RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
}
}
} // namespace rope_neon

Просмотреть файл

@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
#if !defined(DISABLE_SPARSE_TENSORS)
// Create execution frame for executing constant nodes.
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
is_sparse_initializer_check);
is_sparse_initializer_check, logger);
#else
// Create execution frame for executing constant nodes.
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
[](std::string const&) { return false; });
OptimizerExecutionFrame::Info info(
{node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; },
logger);
#endif
std::vector<int> fetch_mlvalue_idxs;

Просмотреть файл

@ -190,6 +190,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
@ -404,7 +405,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
}
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
logger);
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}
@ -437,6 +439,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
@ -490,7 +493,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
#ifndef DISABLE_CONTRIB_OPS
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
logger);
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}

Просмотреть файл

@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
// going to a node that will need a Cast.
//
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph,
const KernelRegistry& cpu_kernel_registry,
const logging::Logger& logger) {
// we can check if it's an isolated fp16 node
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
// does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node),
@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
const KernelCreateInfo* kernel_create_info{};
const auto lookup_status = cpu_kernel_registry.TryFindKernel(
kCpuExecutionProvider, node.OpType(), node.Domain(),
node.SinceVersion(), type_constraint_map, &kernel_create_info);
node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
return true;
}
@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
return false;
}
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry,
const logging::Logger& logger) {
for (auto& node : graph.Nodes()) {
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) {
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
node.SetExecutionProviderType("");
}
@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
return dst_bit_length <= src_bit_length;
}
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") ||
(*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
return true;
}
@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
if (force_cpu_fp32_)
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_));
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger));
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();

Просмотреть файл

@ -44,7 +44,9 @@ NhwcConvLookup(
return &(iter->second);
}
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
std::shared_ptr<KernelRegistry> cpu_kernel_registry,
const logging::Logger& logger) noexcept
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
if (!cpu_kernel_registry) {
// This is a CPU op nodes optimizer, not useful if cpu EP is not available.
@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_,
qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info);
qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_,
qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info);
qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_,
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info);
nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_,
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info);
nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_,
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info);
nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<Ker
const KernelCreateInfo* kernel_create_info{};
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_,
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info);
nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(

Просмотреть файл

@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed.
class NhwcTransformer : public GraphTransformer {
private:
public:
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry,
const logging::Logger& logger) noexcept;
/**
* @brief Usually called right after constructor, it shows whether

Просмотреть файл

@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
const InitializedTensorSet& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
const logging::Logger& logger)
: execution_provider_(execution_provider),
is_sparse_initializer_func_(is_sparse_initializer_func) {
is_sparse_initializer_func_(is_sparse_initializer_func),
logger_(logger) {
allocator_ptr_ = std::make_shared<CPUAllocator>();
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const std::filesystem::path& /* model_path */,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func)
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
const logging::Logger& logger)
: execution_provider_(execution_provider),
is_sparse_initializer_func_(is_sparse_initializer_func) {
is_sparse_initializer_func_(is_sparse_initializer_func),
logger_(logger) {
allocator_ptr_ = std::make_shared<CPUAllocator>();
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const {
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out);
return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out);
}
static Status TryCreateKernel(const Node& node,
@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node,
FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
const ConfigOptions& config_options,
const logging::Logger& logger,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelCreateInfo* kernel_create_info = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger,
&kernel_create_info));
static const AllocatorMap dummy_allocators;
@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
FuncManager func;
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_,
op_kernel);
// Kernel found in the CPU kernel registry

Просмотреть файл

@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
const InitializedTensorSet& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
const logging::Logger& logger);
Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
const std::function<bool(const std::string&)>& is_sparse_initializer_func,
const logging::Logger& logger);
~Info() = default;
@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
std::unique_ptr<NodeIndexInfo> node_index_info_;
const IExecutionProvider& execution_provider_;
const std::function<bool(const std::string&)>& is_sparse_initializer_func_;
const logging::Logger& logger_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info);
};

Просмотреть файл

@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion(
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
}
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) {
constexpr size_t w_idx = 1;
constexpr size_t w_zp_idx = 9;
constexpr size_t r_idx = 2;
@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) ||
!graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) ||
r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
<< " cannot locate recurrence tensor of const int8 type,"
<< " int8 overflow might impact precision !";
return false;
@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) ||
!graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) ||
r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
<< " unable to locate recurrence tensor or its zero point value,"
<< " int8 overflow might impact precision !";
return false;
@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int
if (graph_utils::IsSupportedOptypeVersionAndDomain(
op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) {
// This one has two set of quantized arguments
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph);
modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger);
continue; // go on to next operator node
}

Просмотреть файл

@ -291,7 +291,8 @@ SelectorManager::SelectorManager() {
InitializeSelectorsMap();
}
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer,
const logging::Logger& logger) const {
std::vector<NodeGroup> qdq_selections;
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
const auto* node = graph_viewer.GetNode(index);
@ -313,7 +314,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
if (!versions.empty()) {
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType();
LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType();
continue;
}
}
@ -329,7 +330,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
}
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer) {
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) {
// Get QDQ NodeUnits first
QDQ::SelectorManager selector_mgr;
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger);
for (const auto& qdq_selection : qdq_selections) {
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);

Просмотреть файл

@ -15,7 +15,9 @@
#endif
namespace onnxruntime {
namespace logging {
class Logger;
}
class GraphViewer;
class Node;
@ -65,7 +67,7 @@ class SelectorManager {
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
// Can be used in QDQ support in different EPs
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const;
private:
Selectors qdq_selectors_;
@ -88,7 +90,7 @@ class SelectorManager {
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
// library whereas it should be able to be used by an EP with no dependency on optimizers.
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer);
GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger);
} // namespace QDQ
} // namespace onnxruntime

Просмотреть файл

@ -17,13 +17,22 @@ class TransformerMemcpyImpl {
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
: graph_(graph), provider_(provider) {}
bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);
bool ModifyGraph(const KernelRegistryManager& schema_registries,
const logging::Logger& logger,
int& copy_node_counter);
private:
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
void ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
const InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl);
@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// find defs that require copy
for (auto& node : graph_.Nodes()) {
// as we process the defs, collect all the initializers consumed at the current graph level
ProcessDefs(node, kernel_registries, initializers_consumed);
ProcessDefs(node, kernel_registries, initializers_consumed, logger);
}
// for initializers shared by different providers, create dups
if (ProcessInitializers(kernel_registries, initializers_consumed))
if (ProcessInitializers(kernel_registries, initializers_consumed, logger))
modified = true;
for (auto arg : graph_.GetInputs())
BuildDefsMapping(arg, kernel_registries);
BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : non_provider_input_defs_)
BuildDefsMapping(arg, kernel_registries);
BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : non_provider_output_defs_)
BuildDefsMapping(arg, kernel_registries);
BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : graph_.GetInputs())
// For inputs we need to create a copy node only when the input is connected to both provider
@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
return modified;
}
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed) {
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci));
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci));
bool is_implicit_input = false;
auto process_inputs =
@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
}
// for non_provider defs, collect the nodes that expect it is provider tensor as input/output.
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) {
void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger) {
for (auto& it : graph_.Nodes()) {
if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
auto input_it =
@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
if (arg_input_index != -1) {
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
}
@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co
// We duplicate any initializer that is used by both provider nodes and non-provider nodes
// to ensure that provider nodes and non-provider nodes don't share initializers, as they
// need to stay in different memory locations.
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) {
bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries,
const InitializedTensorSet& initializers_consumed,
const logging::Logger& logger) {
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacements;
for (const auto& pair : initializers_consumed) {
const auto& name = pair.first;
@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
auto dup_replacements = replacements;
const KernelCreateInfo* kci = nullptr;
auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci);
auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
if (kci == nullptr) continue;
if (kci->kernel_def == nullptr) continue;

Просмотреть файл

@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "hardware_core_enumerator.h"
#include "core/platform/windows/env.h"
#include <memory>
#include <Windows.h>
#include <assert.h>
@ -83,6 +84,38 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
// # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
// # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
auto cores = GetCoreInfo();
#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__)
const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI"
bool isIntelSpecifiedPlatform = false;
const int kVendorID_IntelSpecifiedPlatformIDs[3] = {
// ExtendedModel, ExtendedFamily, Family Code, and Model Number
0xa06a, // MTL
0xc065, // ARL-H
0xb065 // ARL-U
};
int regs_leaf0[4];
int regs_leaf1[4];
__cpuid(regs_leaf0, 0);
__cpuid(regs_leaf1, 0x1);
auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]);
for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) {
if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) {
isIntelSpecifiedPlatform = true;
}
}
if (isIntel) {
if (isIntelSpecifiedPlatform) {
// We want to exclude cores without an LLC
return cores.LLCCores;
} else {
return cores.PhysicalCores;
}
}
#endif
return cores.LLCCores;
}

Просмотреть файл

@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node);
if (cann_kernel_def == nullptr) {
LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
<< " node name: " << node.Name();
LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
<< " node name: " << node.Name();
continue;
}
candidates.push_back(node.Index());
}
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates);
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger());
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
continue;

Просмотреть файл

@ -151,7 +151,7 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
return false;
}
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
// To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {

Просмотреть файл

@ -133,9 +133,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu
return false;
}
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
// to pass https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1563483&view=logs&j=f7cc61a9-cc70-56e7-b06c-4668ca17e426
// ReductionOpTest.ReduceSum_half_bert
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
// skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros
int32_t input_type;
GetType(*input_defs[0], input_type, logger);
if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {

Просмотреть файл

@ -13,6 +13,10 @@
#include "core/optimizer/initializer.h"
#include "core/providers/cpu/tensor/unsqueeze.h"
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif
namespace onnxruntime {
namespace coreml {
@ -54,32 +58,50 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
}
}
#if defined(COREML_ENABLE_MLPROGRAM)
void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
const Node& node, const logging::Logger& logger) {
const auto& input_defs(node.InputDefs());
TensorShapeVector axes;
GetAxes(model_builder, node, axes);
std::vector<int64_t> input_shape;
GetShape(*input_defs[0], input_shape, logger);
auto op = model_builder.CreateOperation(node, "reshape");
AddOperationInput(*op, "x", input_defs[0]->Name());
TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
}
#endif
Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
[[maybe_unused]] const logging::Logger& logger) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
const auto& input_defs(node.InputDefs());
auto* coreml_squeeze = layer->mutable_squeeze();
TensorShapeVector axes;
GetAxes(model_builder, node, axes);
std::vector<int64_t> input_shape;
GetShape(*input_defs[0], input_shape, logger);
#if defined(COREML_ENABLE_MLPROGRAM)
const auto& input_defs(node.InputDefs());
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "reshape";
#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64
// expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input
if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) {
HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger);
return Status::OK();
}
#endif
std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims";
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
AddOperationInput(*op, "x", input_defs[0]->Name());
if (coreml_op_type == "squeeze") {
if (!axes.empty()) {
// coreml squeeze op does support negative axes
AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
}
} else {
TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
if (!axes.empty()) {
// coreml supports negative axes
AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
}
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));

Просмотреть файл

@ -408,7 +408,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
: graph_viewer_(graph_viewer),
logger_(logger),
coreml_version_(coreml_version),
coreml_compute_unit_(coreml_options.ComputeUnits()),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
onnx_input_names_(std::move(onnx_input_names)),
@ -989,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
get_sanitized_io_info(std::move(input_output_info_)),
std::move(scalar_outputs_),
std::move(int64_outputs_),
logger_, coreml_compute_unit_);
logger_, coreml_options_);
} else
#endif
{
@ -999,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
std::move(input_output_info_),
std::move(scalar_outputs_),
std::move(int64_outputs_),
logger_, coreml_compute_unit_);
logger_, coreml_options_);
}
return model->LoadModel(); // load using CoreML API, including compilation

Просмотреть файл

@ -7,6 +7,7 @@
#include "core/graph/graph_viewer.h"
#include "core/providers/coreml/builders/coreml_spec.h"
#include "core/providers/coreml/model/model.h"
#include "core/providers/coreml/coreml_options.h"
#if defined(COREML_ENABLE_MLPROGRAM)
// coremltools classes
@ -22,8 +23,6 @@ class StorageWriter;
#endif
namespace onnxruntime {
class CoreMLOptions;
namespace coreml {
class IOpBuilder;
@ -218,7 +217,7 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;
const int32_t coreml_version_;
const uint32_t coreml_compute_unit_;
CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel

Просмотреть файл

@ -63,11 +63,14 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
{"NeuralNetwork", COREML_FLAG_USE_NONE},
};
std::unordered_set<std::string> valid_options = {
const std::unordered_set<std::string_view> valid_options = {
kCoremlProviderOption_MLComputeUnits,
kCoremlProviderOption_ModelFormat,
kCoremlProviderOption_RequireStaticInputShapes,
kCoremlProviderOption_EnableOnSubgraphs,
kCoremlProviderOption_SpecializationStrategy,
kCoremlProviderOption_ProfileComputePlan,
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
};
// Validate the options
for (const auto& option : options) {
@ -90,6 +93,16 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
require_static_shape_ = option.second == "1";
} else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) {
enable_on_subgraph_ = option.second == "1";
} else if (kCoremlProviderOption_SpecializationStrategy == option.first) {
if (option.second != "Default" && option.second != "FastPrediction") {
ORT_THROW("Invalid value for option ", option.first, ": ", option.second,
". Valid values are Default and FastPrediction.");
}
strategy_ = option.second;
} else if (kCoremlProviderOption_ProfileComputePlan == option.first) {
profile_compute_plan_ = option.second == "1";
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
}
}
}

Просмотреть файл

@ -14,6 +14,9 @@ class CoreMLOptions {
bool create_mlprogram_{false};
bool enable_on_subgraph_{false};
uint32_t compute_units_{0};
std::string strategy_;
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
public:
explicit CoreMLOptions(uint32_t coreml_flags);
@ -25,6 +28,9 @@ class CoreMLOptions {
bool CreateMLProgram() const { return create_mlprogram_; }
bool EnableOnSubgraph() const { return enable_on_subgraph_; }
uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; }
bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; }
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }
private:
void ValidateAndParseProviderOption(const ProviderOptions& options);

Просмотреть файл

@ -18,6 +18,7 @@
#endif
namespace onnxruntime {
class CoreMLOptions;
namespace coreml {
class Execution;
@ -53,7 +54,7 @@ class Model {
std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
const logging::Logger& logger, uint32_t coreml_compute_unit);
const logging::Logger& logger, const CoreMLOptions& coreml_options);
~Model();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);

Просмотреть файл

@ -25,6 +25,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/objc_str_utils.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/coreml/coreml_options.h"
// force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need
// to manually do this
@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
return Status::OK();
}
// since __clang_major__ >= 15, MLComputePlan is introduced in <CoreML/CoreML.h>
// We are actually ensure the MacOS/IOS version and Xcode version is greater than `macOS 14.4, iOS 17.4`.
// The macro API_AVAILABLE should also be fine.
// Otherwise, the compiler will complain `MLComputePlan` is not defined.
// we define __clang_analyzer__ here is for bypass static analysis
void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) {
#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
if (@available(macOS 14.4, iOS 17.4, *)) {
[MLComputePlan loadContentsOfURL:compileUrl
configuration:config
completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) {
if (!computePlan) {
NSLog(@"Error loading compute plan: %@", error);
// Handle error.
return;
}
MLModelStructureProgram* program = computePlan.modelStructure.program;
if (!program) {
NSLog(@"Error loading program from compute plan., this is not a mlprogram model");
return;
}
MLModelStructureProgramFunction* mainFunction = program.functions[@"main"];
if (!mainFunction) {
NSLog(@"Error loading main function from program");
return;
}
NSArray<MLModelStructureProgramOperation*>* operations = mainFunction.block.operations;
NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count);
for (MLModelStructureProgramOperation* operation in operations) {
// Get the compute device usage for the operation.
MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation];
id<MLComputeDeviceProtocol> preferredDevice = computeDeviceUsage.preferredComputeDevice;
// Get the estimated cost of executing the operation.
MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation];
if (![operation.operatorName isEqualToString:@"const"]) {
NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight);
}
}
}];
} else {
NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API");
}
#endif
}
// Internal Execution class
// This class is part of the model class and handles the calls into CoreML. Specifically, it performs
// 1. Compile the model by given path for execution
@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
// 3. The compiled model will be removed in dealloc or removed using cleanup function
class Execution {
public:
Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags);
Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options);
~Execution();
Status LoadModel();
@ -320,13 +368,13 @@ class Execution {
NSString* coreml_model_path_{nil};
NSString* compiled_model_path_{nil};
const logging::Logger& logger_;
uint32_t coreml_compute_unit_{0};
CoreMLOptions coreml_options_;
MLModel* model_{nil};
};
Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit)
Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options)
: logger_(logger),
coreml_compute_unit_(coreml_compute_unit) {
coreml_options_(coreml_options) {
@autoreleasepool {
coreml_model_path_ = util::Utf8StringToNSString(path.c_str());
}
@ -395,17 +443,41 @@ Status Execution::LoadModel() {
compiled_model_path_ = [compileUrl path];
MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) {
uint32_t coreml_compute_unit = coreml_options_.ComputeUnits();
if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) {
config.computeUnits = MLComputeUnitsCPUOnly;
} else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) {
} else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) {
config.computeUnits = MLComputeUnitsCPUAndGPU;
} else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
} else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
} else {
config.computeUnits = MLComputeUnitsAll;
}
if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) {
config.allowLowPrecisionAccumulationOnGPU = YES;
}
// Set the specialization strategy to FastPrediction for macOS 10.15+
// since __clang_major__ >= 15, optimizationHints is introduced in <CoreML/CoreML.h>
// Same as above comments for why we are checking __clang_major__.
// we define __clang_analyzer__ here is for bypass static analysis
#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
if (HAS_COREML8_OR_LATER) {
MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init];
if (coreml_options_.UseStrategy("FastPrediction")) {
optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
config.optimizationHints = optimizationHints;
} else if (coreml_options_.UseStrategy("Default")) {
optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
config.optimizationHints = optimizationHints;
}
}
#endif
if (coreml_options_.ProfileComputePlan()) {
ProfileComputePlan(compileUrl, config);
}
model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error];
if (error != nil || model_ == nil) {
@ -524,8 +596,8 @@ Model::Model(const std::string& path,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
const logging::Logger& logger,
uint32_t coreml_flags)
: execution_(std::make_unique<Execution>(path, logger, coreml_flags)),
const CoreMLOptions& coreml_options)
: execution_(std::make_unique<Execution>(path, logger, coreml_options)),
model_input_names_(std::move(model_input_names)),
model_output_names_(std::move(model_output_names)),
input_output_info_(std::move(input_output_info)),

Просмотреть файл

@ -4,6 +4,7 @@
#include "core/providers/coreml/model/model.h"
namespace onnxruntime {
class CoreMLOptions;
namespace coreml {
class Execution {};
@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
const logging::Logger& /*logger*/,
uint32_t /*coreml_flags*/)
const CoreMLOptions& /*coreml_flags*/)
: execution_(std::make_unique<Execution>()),
model_input_names_(std::move(model_input_names)),
model_output_names_(std::move(model_output_names)),

Просмотреть файл

@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)

Просмотреть файл

@ -62,7 +62,8 @@ namespace Dml
const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
const auto kernel_lookup = onnxruntime::KernelLookup{provider_type,
gsl::make_span(&registry, 1),
kernel_type_str_resolver};
kernel_type_str_resolver,
logger};
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;

Просмотреть файл

@ -54,7 +54,8 @@ namespace Dml
const auto kernelLookup = onnxruntime::KernelLookup(
providerType,
gsl::make_span(&registry, 1),
kernelTypeStrResolver);
kernelTypeStrResolver,
logger);
onnxruntime::GraphViewer graphViewer(graph);
const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder();

Просмотреть файл

@ -95,7 +95,7 @@ namespace Dml
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
{
#ifdef ENABLE_GRAPH_COMPILATION
return m_impl->GetCapability(graph, kernel_lookup);
return m_impl->GetCapability(graph, kernel_lookup, *GetLogger());
#else
return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup);
#endif
@ -876,7 +876,8 @@ namespace Dml
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
ExecutionProviderImpl::GetCapability(
const onnxruntime::GraphViewer& graph,
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
const onnxruntime::logging::Logger& logger) const
{
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
@ -900,7 +901,7 @@ namespace Dml
}
// Get the list of nodes that should stay on the CPU
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes);
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger);
for (size_t nodeIndex : toplogicalOrder)
{

Просмотреть файл

@ -88,7 +88,8 @@ namespace Dml
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
GetCapability(
const onnxruntime::GraphViewer& graph,
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
const onnxruntime::logging::Logger& logger
) const;
uint32_t GetSupportedDeviceDataTypeMask() const;

Просмотреть файл

@ -818,7 +818,7 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
candidates.push_back(node.Index());
tenative_candidates.push_back(node.Index());
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0) {

Просмотреть файл

@ -32,8 +32,16 @@ namespace nnapi {
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
gsl::span<const DeviceWrapper> nnapi_target_devices,
TargetDeviceOption target_device_option)
: nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique<Model>(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) {
TargetDeviceOption target_device_option,
const logging::Logger& logger)
: nnapi_(nnapi_handle),
graph_viewer_(graph_viewer),
nnapi_model_{std::make_unique<Model>(nnapi_handle)},
shaper_{graph_viewer},
nnapi_target_devices_(nnapi_target_devices),
target_device_option_(target_device_option),
nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)),
logger_(logger) {
nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_;
}
@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
}
void ModelBuilder::PreprocessNodeUnits() {
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
}
// Help to get all quantized operators' input and the NodeUnit(s) using the input

Просмотреть файл

@ -14,7 +14,9 @@
struct NnApi;
namespace onnxruntime {
namespace logging {
class Logger;
}
class GraphViewer;
enum class DataLayout;
class NodeUnit;
@ -31,7 +33,8 @@ class ModelBuilder {
using Shape = Shaper::Shape;
ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option);
gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option,
const logging::Logger& logger);
common::Status Compile(std::unique_ptr<Model>& model);
@ -173,6 +176,9 @@ class ModelBuilder {
// <1,1> <1,2> <1,3>
InlinedVector<std::pair<size_t, int32_t>> operations_recorder_;
#endif
const logging::Logger& logger_;
// Convert the ONNX model to ANeuralNetworksModel
common::Status Prepare();

Просмотреть файл

@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {}
std::vector<std::unique_ptr<ComputeCapability>>
NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/) const {
const auto& logger = *GetLogger();
std::vector<std::unique_ptr<ComputeCapability>> result;
// TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators)
@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL;
#endif
}();
LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
const nnapi::OpSupportCheckParams params{
android_feature_level,
@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
};
if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) {
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
<< params.android_feature_level
<< "] is lower than minimal supported NNAPI API feature level ["
<< ORT_NNAPI_MIN_API_LEVEL
@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// This holds the result of whether a NodeUnit is supported or not,
// to prevent nodes in a NodeUnit to be checked for multiple times
@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
node_unit_supported_result[node_unit] = supported;
}
LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported
LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
// If the graph is partitioned in multiple subgraphs, and this may impact performance,
// we want to give users a summary message at warning level.
if (num_of_partitions > 1) {
LOGS_DEFAULT(WARNING) << summary_msg;
LOGS(logger, WARNING) << summary_msg;
} else {
LOGS_DEFAULT(INFO) << summary_msg;
LOGS(logger, INFO) << summary_msg;
}
return result;
@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context,
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
using namespace android::nn::wrapper;
const auto& logger = *GetLogger();
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
Node& fused_node = fused_node_and_graph.fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_);
nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger);
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);

Просмотреть файл

@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger);
std::unordered_set<const NodeUnit*> seen_node_units;
const auto& node_indices = src_graph.GetNodesInTopologicalOrder();

Просмотреть файл

@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models) {
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
qnn_models);
qnn_models,
max_spill_fill_size);
}
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
qnn_models);
qnn_models,
max_spill_fill_size);
}
Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
uint32_t total_context_size,
int64_t& max_spill_fill_size,
std::vector<int>& main_context_pos_list) {
max_spill_fill_size = 0;
int max_size_index = 0;
for (uint32_t i = 0; i < total_context_size; ++i) {
auto index = main_context_pos_list[i];
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*ep_context_node);
int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
if (max_size > max_spill_fill_size) {
max_spill_fill_size = max_size;
max_size_index = i;
}
}
if (0 != max_size_index) {
int tmp_index = main_context_pos_list[0];
main_context_pos_list[0] = main_context_pos_list[max_size_index];
main_context_pos_list[max_size_index] = tmp_index;
}
return Status::OK();
}
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger) {
const logging::Logger& logger,
int64_t max_spill_fill_size) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
qnn_models);
qnn_models, max_spill_fill_size);
// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
const QnnModelLookupTable& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger) {
auto& graph = model->MainGraph();
@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model,
}
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
}
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));

Просмотреть файл

@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
static const std::string MAX_SIZE = "max_size";
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models);
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size);
Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
uint32_t total_context_size,
int64_t& max_spill_fill_size,
std::vector<int>& main_context_pos_list);
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger);
const logging::Logger& logger,
int64_t max_spill_fill_size);
Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime

Просмотреть файл

@ -8,6 +8,7 @@
#include <string>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
}
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
QnnHtpContext_CustomConfig_t customConfig;
customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
context_config_weight_sharing.customConfig = &customConfig;
context_config_weight_sharing.customConfig = &custom_config;
QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@ -615,9 +616,9 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
return context_buffer;
}
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
QnnModelLookupTable& qnn_models) {
Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
uint64_t buffer_length,
uint64_t& max_spill_fill_buffer_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@ -638,7 +639,69 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
// binary_info life cycle is here
// Binary info to graph info
// retrieve Qnn graph infor from binary info
// retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
graph_count = binary_info->contextBinaryInfoV3.numGraphs;
graphs_info = binary_info->contextBinaryInfoV3.graphs;
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
graph_count = binary_info->contextBinaryInfoV2.numGraphs;
graphs_info = binary_info->contextBinaryInfoV2.graphs;
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
graph_count = binary_info->contextBinaryInfoV1.numGraphs;
graphs_info = binary_info->contextBinaryInfoV1.graphs;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
}
for (uint32_t i = 0; i < graph_count; ++i) {
if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
}
} else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
}
}
LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
return Status::OK();
}
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
ORT_RETURN_IF(result, "Failed to get valid function pointer.");
QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
static_cast<void*>(buffer),
buffer_length,
&binary_info,
&binary_info_size);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
// binary_info life cycle is here
// Binary info to graph info
// retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
// Register spill fill buffer for multi context
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
// The spill fill buffer is available since 2.28, API version starts from 2.21
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
QnnHtpContext_GroupRegistration_t group_info;
size_t current_contexts_size = GetQnnContextSize();
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
// note that we already move the context with max spill fill size to the beginning of the list
group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
custom_config.groupRegistration = group_info;
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
spill_fill_config.customConfig = &custom_config;
#endif
QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
Qnn_ContextHandle_t context = nullptr;
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
buffer_length,
&context,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
if (1 == graph_count) {
// in case the EPContext node is generated from script
@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
return Status::OK();
}
Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
// need to load system lib if load from Qnn context binary
// or generate Qnn context binary is enabled -- to get the max spill fill buffer size
Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
bool load_from_cached_context,
bool need_load_system_lib) {
std::lock_guard<std::mutex> lock(logger_mutex_);
if (backend_setup_completed_) {
LOGS(logger, VERBOSE) << "Backend setup already!";
@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
LOGS(logger, VERBOSE) << "LoadBackend succeed.";
if (load_from_cached_context) {
if (load_from_cached_context || need_load_system_lib) {
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
}
@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
return Status::OK();
}
void QnnBackendManager::Split(std::vector<std::string>& split_string,
const std::string& tokenized_string,
const char separator) {
split_string.clear();
std::istringstream tokenized_string_stream(tokenized_string);
while (!tokenized_string_stream.eof()) {
std::string value;
getline(tokenized_string_stream, value, separator);
if (!value.empty()) {
split_string.push_back(value);
}
}
}
Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);

Просмотреть файл

@ -93,9 +93,10 @@ class QnnBackendManager {
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
int64_t max_spill_fill_size);
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
@ -112,6 +113,10 @@ class QnnBackendManager {
return contexts_[index];
}
size_t GetQnnContextSize() {
return contexts_.size();
}
const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
@ -145,8 +150,6 @@ class QnnBackendManager {
void ReleaseResources();
void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);
Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
@ -163,6 +166,10 @@ class QnnBackendManager {
Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
Status GetMaxSpillFillBufferSize(unsigned char* buffer,
uint64_t buffer_length,
uint64_t& max_spill_fill_buffer_size);
private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);

Просмотреть файл

@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
// valid throughout the lifetime of the ModelBuilder
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// This name must be same with the EPContext node name
const auto& graph_name = fused_node.Name();

Просмотреть файл

@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}
bool enable_htp_weight_sharing = false;
static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
if ("1" == htp_weight_sharing_enabled_pos->second) {
enable_htp_weight_sharing_ = true;
enable_htp_weight_sharing = true;
} else if ("0" == htp_weight_sharing_enabled_pos->second) {
enable_htp_weight_sharing_ = false;
enable_htp_weight_sharing = false;
} else {
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
<< " only 0 or 1 allowed. Set to 0.";
}
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
}
// Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform
enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map);
model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
provider_options_map);
@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
device_id_,
htp_arch,
soc_model,
enable_htp_weight_sharing_);
enable_htp_weight_sharing);
#ifdef _WIN32
auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
// It will load the QnnSystem lib if is_qnn_ctx_model=true, and
// delay the Qnn context creation to Compile() using the cached context binary
auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
// or generate context cache enable, need to use use QnnSystem lib to parse the binary to get the max spill fill buffer size
auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
return result;
@ -713,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// remove is_qnn_ctx_model related code
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map,
@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
std::vector<int> main_context_pos_list;
ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());
int64_t max_spill_fill_size = 0;
// Adjust the main_context_pos_list, move the one with max spill fill buffer to the beginning
// HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28
if (total_context_size > 1) {
ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size,
max_spill_fill_size, main_context_pos_list));
}
for (auto main_context_pos : main_context_pos_list) {
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
context_cache_path,
qnn_backend_manager_.get(),
qnn_models,
logger));
logger,
max_spill_fill_size));
}
for (auto fused_node_and_graph : fused_nodes_and_graphs) {
@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
// All partitioned graph share single QNN context, included in the same context binary
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
// Get max spill fill buffer size
uint64_t max_spill_fill_buffer_size = 0;
if (enable_spill_fill_buffer_) {
ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
buffer_size,
max_spill_fill_buffer_size));
}
qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
context_buffer.get(),
@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
qnn_models_,
context_cache_path,
qnn_context_embed_mode_,
max_spill_fill_buffer_size,
logger));
}
return Status::OK();

Просмотреть файл

@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
bool enable_htp_weight_sharing_ = false;
int32_t vtcm_size_in_mb_ = 0;
std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
ModelMetadefIdGenerator metadef_id_generator_;
@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider {
uint32_t default_rpc_control_latency_ = 0;
bool enable_HTP_FP16_precision_ = true;
bool share_ep_contexts_ = false;
bool enable_spill_fill_buffer_ = false;
#ifdef _WIN32
onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr;
#endif

Просмотреть файл

@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For ROCM EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)

Просмотреть файл

@ -294,7 +294,8 @@ std::unique_ptr<IDataTransfer> CreateGPUDataTransfer();
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger);
std::string GetEnvironmentVar(const std::string& var_name);
@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
namespace QDQ {
inline std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer* graph_viewer) {
return g_host->QDQ__GetAllNodeUnits(graph_viewer);
GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) {
return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger);
}
} // namespace QDQ

Просмотреть файл

@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) {
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes) {
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger) {
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
}
namespace profiling {

Просмотреть файл

@ -202,7 +202,8 @@ struct ProviderHost {
virtual std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes) = 0;
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
@ -890,7 +891,7 @@ struct ProviderHost {
virtual std::unique_ptr<Node__EdgeIterator> NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0;
virtual std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0;
QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0;
// Model
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,

Просмотреть файл

@ -34,7 +34,8 @@ namespace onnxruntime {
namespace vsi {
namespace npu {
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) {
GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger)
: graph_viewer_(graph_viewer), logger_(logger) {
Prepare();
context_ = tim::vx::Context::Create();
graph_ = context_->CreateGraph();
@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g
}
bool GraphEP::Prepare() {
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
for (const auto& node_unit : node_unit_holder_) {
auto quant_op_type = util::GetQuantizedOpType(*node_unit);

Просмотреть файл

@ -51,7 +51,7 @@ struct NodeIOInfo {
class GraphEP {
public:
explicit GraphEP(const GraphViewer& graph_viewer);
explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger);
~GraphEP() {}
bool Prepare();
@ -104,6 +104,7 @@ class GraphEP {
// In the form of {input_name, [NodeUnit(s) using the input]}
std::unordered_map<std::string, std::vector<const NodeUnit*>> all_quantized_op_inputs_;
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;
// Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
// valid throughout the lifetime of the ModelBuilder

Просмотреть файл

@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {}
std::vector<std::unique_ptr<ComputeCapability>>
VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/) const {
const auto& logger = *GetLogger();
std::vector<std::unique_ptr<ComputeCapability>> result;
if (graph_viewer.IsSubgraph()) {
@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// This holds the result of whether a NodeUnit is supported or not,
// to prevent nodes in a NodeUnit to be checked for multiple times
@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
}
Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
OrtKernelContext* context) {
OrtKernelContext* context,
const logging::Logger& logger) {
Ort::KernelContext ctx(context);
size_t num_in = ctx.GetInputCount();
const size_t num_inputs = graph_ep->GetGraphInputs().size();
@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
}
if (!graph_ep->GetGraph()->Run()) {
LOGS_DEFAULT(ERROR) << "Failed to run graph.";
LOGS(logger, ERROR) << "Failed to run graph.";
}
for (size_t i = 0; i < ctx.GetOutputCount(); i++) {
auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor;
@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep,
Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
const auto& logger = *GetLogger();
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_viewer = fused_node_graph.filtered_graph;
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer);
std::shared_ptr<vsi::npu::GraphEP> graph_ep = std::make_shared<vsi::npu::GraphEP>(graph_viewer, logger);
for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) {
LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#"
<< graph_viewer.IsInitializedTensor(tensor->Name());
auto input = std::make_shared<vsi::npu::GraphIOInfo>();
input->name = tensor->Name();
@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
graph_ep->GetGraphInputs().push_back(input);
}
for (auto tensor : graph_viewer.GetOutputs()) {
LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor);
auto output = std::make_shared<vsi::npu::GraphIOInfo>();
output->name = tensor->Name();
output->is_initializer = false;
@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
if (node != &node_unit.GetNode()) {
continue;
}
LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]";
LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]";
vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit);
}
LOGS_DEFAULT(INFO) << "Verifying graph";
LOGS(logger, INFO) << "Verifying graph";
graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile();
if (!graph_ep->GetCompiled()) {
LOGS_DEFAULT(ERROR) << "Failed to verify graph.";
LOGS(logger, ERROR) << "Failed to verify graph.";
} else {
LOGS_DEFAULT(INFO) << "Graph has been verified successfully.";
LOGS(logger, INFO) << "Graph has been verified successfully.";
}
NodeComputeInfo compute_info;
@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
[graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */,
OrtKernelContext* context) {
std::lock_guard<std::mutex> lock(this->GetMutex());
Status res = ComputeStateFunc(graph_ep.get(), context);
Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger());
return res;
};

Просмотреть файл

@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
1, 8,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
9, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
13, 20,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
21,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime

Просмотреть файл

@ -798,7 +798,8 @@ std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapa
candidates.push_back(node.Index());
tenative_candidates.push_back(node.Index());
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0) {

Просмотреть файл

@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
if (dim.dim_value() == 0) {
if (dim.dim_value() == 0 && !allow_empty_input) {
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
return false;
}

Просмотреть файл

@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input = false);
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,

Просмотреть файл

@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
return false;
}
}

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше