Update XNNPACK to latest version (#18038)

### Description
Update XNNPACK to latest version
- adds fp16 kernels and various other improvements
- requires pthreadpool update as well

Most code updates in the XNNPACK EP are to adjust to the new XNNPACK API
- 'setup' is split into 'reshape' and 'setup' (see the sketch after this list)
-  some ops use a workspace buffer
   -  copied workspace allocation from XNNPACK unit test code
- some suffixes changed 
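
A minimal sketch of the new two-phase call sequence, based on the Gemm/MatMul changes in this PR (the helper function, its name and its arguments are illustrative stand-ins, not actual EP code):

```cpp
#include "xnnpack.h"
#include "pthreadpool.h"

// Old API: a single xnn_setup_* call took shapes, buffers and the threadpool.
// New API: 'reshape' does the shape-dependent planning, 'setup' only binds buffers.
xnn_status RunFullyConnected(xnn_operator_t op, pthreadpool_t threadpool,
                             size_t num_rows, const float* input, float* output) {
  // 1. reshape with the dynamic dimension(s) and the threadpool
  xnn_status status = xnn_reshape_fully_connected_nc_f32(op, num_rows, threadpool);
  if (status != xnn_status_success) return status;

  // 2. setup binds only the input/output pointers
  status = xnn_setup_fully_connected_nc_f32(op, input, output);
  if (status != xnn_status_success) return status;

  // 3. run is unchanged
  return xnn_run_operator(op, threadpool);
}
```

For ops that need a workspace (e.g. AveragePool), the reshape call also returns a workspace size/alignment; the kernel allocates that buffer via XNNPACK's allocator (following the XNNPACK unit test code) before calling setup.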

Added wrapper for XNNPACK caches to base XNNPACK EP kernel
- simplifies usage (see the sketch after this list)
- XNNPACK split out the code and weights caches, but the code cache
isn't currently usable via the public API
- we could use the internal types if we think it's required for performance
reasons, but that's non-trivial as we'd need to propagate ifdef values
from the XNNPACK build up to the ORT build
- using XNNPACK internals would also mean we would not be able to
support using a pre-built XNNPACK package
    - not an issue currently
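
A rough sketch of how a derived kernel opts in to the cache wrapper, following the pattern in the Gemm/MatMul changes in this PR (the kernel class and the create-call arguments are illustrative stand-ins; the usual EP includes are assumed):

```cpp
class ExampleKernel : public XnnpackKernel {
 public:
  explicit ExampleKernel(const OpKernelInfo& info)
      : XnnpackKernel(info, /*enable_caches*/ true) {}  // opt in to the base-class caches

  Status CreateOp(size_t input_channels, size_t output_channels,
                  const float* kernel_data, const float* bias_data,
                  float output_min, float output_max, uint32_t flags) {
    xnn_operator_t p = nullptr;
    xnn_status status = xnn_create_fully_connected_nc_f32(
        input_channels, output_channels,
        /*input_stride=*/input_channels, /*output_stride=*/output_channels,
        kernel_data, bias_data, output_min, output_max, flags,
        GetCodeCache(),     // code cache; not currently usable via the public XNNPACK API
        GetWeightsCache(),  // weights cache owned by the XnnpackKernel base class
        &p);
    if (status != xnn_status_success) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_f32 returned ", status);
    }
    op_.reset(p);
    return Status::OK();
  }

 private:
  XnnpackOperator op_ = nullptr;
};
```

This replaces the per-kernel `#ifdef XNN_CACHE_ENABLE` members (`xnn_caches`, `xnn_code_cache`, `xnn_weights_cache`) that Gemm and MatMul previously declared.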
  
Fixed opset registration for internal NHWC domain
- was not being tied to the ONNX version, so nodes inserted by layout
transformation had the incorrect opset (the core of the fix is shown below)
- a number of other places needed updating once this issue was fixed
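
The core of the fix, from the `Model` constructor change further down in this diff: if the internal NHWC domain isn't explicitly imported, it now defaults to the model's ONNX opset.

```cpp
// special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported
if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) {
  auto onnx_version = domain_to_version.find(kOnnxDomain);
  if (onnx_version != domain_to_version.end()) {
    domain_to_version[kMSInternalNHWCDomain] = onnx_version->second;
  }
}
```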

Removed support for NCHW Resize from the XNNPACK EP so it's NHWC only
- we only supported NCHW for fp32
- doing so adds complexity in multiple places (XNNPACK EP kernel
implementation, layout transformation and transpose optimization)
- unclear if that complexity provides any benefit; it can be added back if
required by a production scenario

### Motivation and Context
We're looking at enabling fp16 support for CoreML and NNAPI. If we do
that, we need a good fallback story for when the CPU EP is used. The
XNNPACK fp16 kernels will hopefully provide that.

NOTE: This PR doesn't add fp16 support to the XNNPACK EP kernels. That
can be done as required in separate PRs and should be relatively simple
to do.
Scott McKay 2023-11-04 02:04:28 +10:00 committed by GitHub
Parent e36d003765
Commit 4f2096be38
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
45 changed files: 814 additions and 537 deletions


@ -3,10 +3,15 @@
#The columns are separated by ";" because a list in cmake is just a ";" separated group of strings.
#Names should be in lower case. They will be used as variable names in cmake.
#URLs can be either https URLs or local file paths in cmake-style(directory separator is a forward slash character).
#SHA1 hashes can be generated by running sha1sum command.
#SHA1 hashes can be generated by running sha1sum command on linux. PowerShell can also be used:
# (Get-FileHash -Algorithm SHA1 <filename>).Hash.ToLower()
#If you need to change abseil's version to a different one, you may also want to update external\abseil-cpp.natvis
#since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would
#not affect built binaries.
#
# NOTE: You must run deps_update_and_upload.py when ready to test your changes in a CI.
# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29
#
abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
@ -18,7 +23,7 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc
googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c
googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
@ -35,7 +40,7 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v
protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617
protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef
psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013
pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/1787867f6183f056420e532eec640cba25efafea.zip;e43e80781560c5ab404a4da20f34d846f5f5d101
pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip;07a0aa91dd9bf86f31b95497e00f31d8a261a4bd
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13
pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip;85da3caa60eb2b148613b443fbc2bfdc30689965
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374

cmake/external/xnnpack.cmake (vendored)

@ -25,17 +25,23 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR})
FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool})
onnxruntime_fetchcontent_makeavailable(pthreadpool)
FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch)
FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch
)
onnxruntime_fetchcontent_makeavailable(googlexnnpack)
set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR})
set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include)
set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool)
# the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
# See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels
message("Adding WebAssembly Source Files to XNNPACK")
set(wasm_srcs "")
file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config)
# Replace newlines with semicolon so that it is treated as a list by CMake
@ -70,25 +76,23 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(${target_srcs} ${bazel_srcs} PARENT_SCOPE)
endfunction()
GetSrcListFromBazel("PROD_SCALAR_WASM_MICROKERNEL_SRCS" prod_scalar_wasm_srcs)
GetSrcListFromBazel("ALL_WASM_MICROKERNEL_SRCS" all_wasm_srcs)
GetSrcListFromBazel("WASM32_ASM_MICROKERNEL_SRCS" wasm32_asm_srcs)
GetSrcListFromBazel("OPERATOR_SRCS" operator_srcs)
GetSrcListFromBazel("TABLE_SRCS" table_srcs)
list(APPEND wasm_srcs ${operator_srcs} ${table_srcs})
message(DEBUG "prod_scalar_wasm_srcs: ${prod_scalar_wasm_srcs}\n")
message(DEBUG "all_wasm_srcs: ${all_wasm_srcs}\n")
message(DEBUG "wasm32_asm_srcs: ${wasm32_asm_srcs}\n")
message("Adding WebAssembly Source Files to XNNPACK")
set(wasm_srcs "")
list(APPEND wasm_srcs ${prod_scalar_wasm_srcs})
list(APPEND wasm_srcs ${all_wasm_srcs})
list(APPEND wasm_srcs ${wasm32_asm_srcs})
target_sources(XNNPACK PRIVATE ${wasm_srcs})
# kernels
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c)
if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
GetSrcListFromBazel("ALL_WASMSIMD_MICROKERNEL_SRCS" all_wasmsimd_srcs)
message(DEBUG "all_wasmsimd_srcs: ${all_wasmsimd_srcs}")
target_sources(XNNPACK PRIVATE ${all_wasmsimd_srcs})
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
target_compile_options(XNNPACK PRIVATE "-msimd128")
endif()
message(DEBUG "wasm_srcs: ${wasm_srcs}\n")
target_sources(XNNPACK PRIVATE ${wasm_srcs})
# add flags from BAZEL.build
target_compile_options(XNNPACK PRIVATE "-fno-fast-math")
target_compile_options(XNNPACK PRIVATE "-fno-math-errno")
endif()


@ -15,7 +15,8 @@
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool
flatbuffers::flatbuffers Boost::mp11 safeint_interface
)
add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
@ -35,4 +36,4 @@
# there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline'
if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32)
endif()
endif()


@ -41,7 +41,7 @@ function(AddTest)
if (MSVC)
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6330>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6330>")
#Abseil has a lot of C4127/C4324 warnings.
#Abseil has a lot of C4127/C4324 warnings.
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4127>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4127>")
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4324>"
@ -201,8 +201,18 @@ function(AddTest)
list(APPEND TEST_NODE_FLAGS "--experimental-wasm-simd")
endif()
# prefer Node from emsdk so the version is more deterministic
if (DEFINED ENV{EMSDK_NODE})
set(NODE_EXECUTABLE $ENV{EMSDK_NODE})
else()
# warning as we don't know what node version is being used and whether things like the TEST_NODE_FLAGS
# will be valid. e.g. "--experimental-wasm-simd" is not valid with node v20 or later.
message(WARNING "EMSDK_NODE environment variable was not set. Falling back to system `node`.")
set(NODE_EXECUTABLE node)
endif()
add_test(NAME ${_UT_TARGET}
COMMAND node ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS}
COMMAND ${NODE_EXECUTABLE} ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS}
WORKING_DIRECTORY $<TARGET_FILE_DIR:${_UT_TARGET}>
)
endif()


@ -192,8 +192,13 @@ else()
onnxruntime_util
re2::re2
)
set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'")
if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'")
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ALLOW_TABLE_GROWTH=1")
endif()
if(onnxruntime_USE_WEBNN)
@ -204,7 +209,6 @@ else()
target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard)
endif()
set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']")
if (onnxruntime_USE_JSEP)
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
else()
@ -212,7 +216,7 @@ else()
endif()
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}"
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
"SHELL:-s MAXIMUM_MEMORY=4294967296"
"SHELL:-s EXIT_RUNTIME=0"


@ -1,66 +1,27 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d53c48aa1..77c3cf983 100755
index dba9b4687..bcaa18ad7 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -105,22 +105,12 @@ ENDIF()
@@ -122,7 +122,7 @@ ENDIF()
# ---[ Build flags
IF(NOT CMAKE_SYSTEM_NAME)
MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined")
-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS)$")
+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS|Emscripten|iOS)$")
MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}")
-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT)$")
+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT|Emscripten|iOS)$")
MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"")
ENDIF()
# ---[ Download deps
IF(NOT XNNPACK_USE_SYSTEM_LIBS)
- IF(NOT DEFINED CLOG_SOURCE_DIR)
- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)")
- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt")
- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download")
- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build .
- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download")
- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory")
- ENDIF()
-
IF(NOT DEFINED CPUINFO_SOURCE_DIR)
MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)")
CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt")
@@ -7108,6 +7098,10 @@ IF(MSVC)
SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O2 >")
SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O2 >")
SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O1 >")
+ELSEIF(CMAKE_GENERATOR STREQUAL Xcode)
+ TARGET_COMPILE_OPTIONS(all_microkernels PRIVATE $<$<NOT:$<CONFIG:Debug>>: -O2 >)
+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$<NOT:$<CONFIG:Debug>>: -O2 >)
+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$<NOT:$<CONFIG:Debug>>: -Os >)
ELSE()
SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: -O2 >")
SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: -O2 >")
@@ -7142,26 +7136,6 @@ IF(LIBM)
TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM})
IF(CMAKE_SYSTEM_NAME MATCHES "Windows")
@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY)
TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging)
TARGET_LINK_LIBRARIES(post-operation PRIVATE logging)
TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run)
- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
+ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph)
+ ELSE()
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
+ ENDIF()
SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES)
ENDIF()
-# ---[ Configure clog
-IF(NOT TARGET clog)
- IF(NOT XNNPACK_USE_SYSTEM_LIBS)
- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "")
- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "")
- ADD_SUBDIRECTORY(
- "${CLOG_SOURCE_DIR}/deps/clog"
- "${CMAKE_BINARY_DIR}/clog")
- # We build static version of clog but a dynamic library may indirectly depend on it
- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON)
- ELSE()
- ADD_LIBRARY(clog STATIC IMPORTED)
- FIND_LIBRARY(CLOG_LIBRARY clog)
- IF(NOT CLOG_LIBRARY)
- MESSAGE(FATAL_ERROR "Cannot find clog")
- ENDIF()
- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}")
- ENDIF()
-ENDIF()
-
# ---[ Configure cpuinfo
IF(NOT TARGET cpuinfo)
IF(NOT XNNPACK_USE_SYSTEM_LIBS)
IF(NOT MSVC)


@ -20,15 +20,15 @@ Do not modify directly.*
| Asinh | ai.onnx(9+) | |
| Atan | ai.onnx(7+) | |
| Atanh | ai.onnx(9+) | |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
| BiasAdd | com.microsoft(1+) | |
| BiasSplitGelu | com.microsoft(1+) | |
| Cast | ai.onnx(6-8,9-12,13-18,19+) | |
| Ceil | ai.onnx(6-12,13+) | |
| Clip | ai.onnx(6-10,11,12,13+) | |
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; conv3d is not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
| Div | ai.onnx(7-12,13,14+) | |
@ -57,7 +57,7 @@ Do not modify directly.*
| LessOrEqual | ai.onnx(12-15,16+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
| MemcpyFromHost | ai.onnx(1+) | |
| MemcpyToHost | ai.onnx(1+) | |
| Mul | ai.onnx(7-12,13,14+) | |
@ -79,7 +79,7 @@ Do not modify directly.*
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
| Relu | ai.onnx(6-12,13,14+) | |
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
| Sin | ai.onnx(7+) | |


@ -164,7 +164,10 @@ async function initializeSession(
session = await ort.InferenceSession.create(modelFilePath, sessionConfig);
}
} catch (e) {
Logger.error('TestRunner', `Failed to load model from file: ${modelFilePath}. Error: ${inspect(e)}`);
Logger.error(
'TestRunner',
`Failed to load model from file: ${modelFilePath}. ` +
`Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`);
throw e;
}


@ -62,8 +62,13 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
auto create_error_message = [&node, &status](const std::string& prefix) {
std::ostringstream errormsg;
errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")";
errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). ";
errormsg << prefix;
const auto& domain = node.Domain();
if (!domain.empty()) {
errormsg << domain << ".";
}
errormsg << node.OpType() << "(" << node.SinceVersion() << ")"
<< " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). ";
if (!status.IsOK())
errormsg << status.ErrorMessage();


@ -90,12 +90,16 @@ void RegisterNHWCSchemaWithActivation(const RegistrationFunc& f, ::ONNX_NAMESPAC
void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMESPACE::OpSchema&&)>& fn) {
// if the operator may be fused with an activation, use the WITH_ACTIVATION variant to add optional attributes
// for the activation parameters.
// For now we only register operators from opset 11 on. Models can easily have their opset updated using ONNX tools
// We mainly register operators from opset 11 on . Models can easily have their opset updated using ONNX tools
// so supporting older opsets is unnecessary.
// Older opsets are included on a per-operator basis as needed.
// NOTE: This should be in sync with GetLayoutSensitiveOps in
// /onnxruntime/core/optimizer/transpose_optimization/transpose_optimizer.cc
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 7);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 10);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 19);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 9);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14);
@ -106,16 +110,18 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, InstanceNormalization, 6);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, Conv, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, Conv, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 11);
REGISTER_NHWC_SCHEMA(fn, GlobalAveragePool, 1);
REGISTER_NHWC_SCHEMA(fn, GlobalLpPool, 2);
REGISTER_NHWC_SCHEMA(fn, GlobalMaxPool, 1);
REGISTER_NHWC_SCHEMA(fn, GridSample, 16);
REGISTER_NHWC_SCHEMA(fn, GridSample, 20);
REGISTER_NHWC_SCHEMA(fn, LRN, 1);
REGISTER_NHWC_SCHEMA(fn, LRN, 13);
@ -123,6 +129,9 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES
REGISTER_NHWC_SCHEMA(fn, LpPool, 11);
REGISTER_NHWC_SCHEMA(fn, LpPool, 18);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 8);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 10);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 12);


@ -3908,7 +3908,8 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std
// kernel lookup works as per usual, if not using an existing schema.
if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) {
ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node),
"Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType());
"Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType(),
" SinceVersion:", fused_node.SinceVersion());
} else if (IndexedSubGraph::SourceOfSchema::REUSE_OR_CREATE == sub_graph.schema_source) {
auto schema_key = GenerateSchemaKey(sub_graph);
if (reusable_fused_schema_map_.count(schema_key) == 0) {


@ -232,6 +232,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
}
}
// special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported
if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) {
auto onnx_version = domain_to_version.find(kOnnxDomain);
if (onnx_version != domain_to_version.end()) {
domain_to_version[kMSInternalNHWCDomain] = onnx_version->second;
}
}
auto domain_map = allow_official_onnx_release_only_final
? schema_registry->GetLastReleasedOpsetVersions(false)
: schema_registry->GetLatestOpsetVersions(false);


@ -66,17 +66,6 @@ bool ConvertNodeLayout(const api::NodeRef& node) {
const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
// handle special cases
#if defined(USE_XNNPACK)
if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) {
if (node.OpType() == "Resize") {
// XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize
// with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through
// the Resize during transpose optimization.
return false;
}
}
#endif
#if defined(USE_JSEP)
// TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed
if (node.GetExecutionProviderType() == kJsExecutionProvider) {


@ -17,7 +17,7 @@ static bool EPAwareHandleResize(HandlerArgs& args) {
// layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled
// by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs.
const auto ep_type = args.node.GetExecutionProviderType();
if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) {
if (ep_type == kCpuExecutionProvider) {
// allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model
int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
if (rank_int == 4) {


@ -352,7 +352,7 @@ class UpsampleBase {
(scales.size() == 4 && scales[0] == 1 && scales[3] == 1) ||
scales.size() == 3 ||
(scales.size() == 5 && scales[0] == 1 && scales[1] == 1),
"'Linear' mode only support:\n"
"'Linear' mode only supports:\n"
" * 2-D inputs or\n"
" * 3-D inputs ('Bilinear', 'Trilinear') or\n"
" * 4-D inputs with the corresponding outermost 2 scale values being 1"


@ -238,18 +238,40 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Whe
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
@ -257,17 +279,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax);
@ -291,11 +302,17 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
@ -304,11 +321,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
@ -322,8 +334,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, InstanceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range);
@ -508,18 +521,40 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
@ -527,17 +562,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
@ -575,7 +599,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize)>,
@ -594,8 +618,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range)>,


@ -25,6 +25,13 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<false>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Conv,
kMSInternalNHWCDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<true>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Conv,
kOnnxDomain,


@ -24,6 +24,13 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<false>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
ConvTranspose,
kMSInternalNHWCDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<true>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
ConvTranspose,
kOnnxDomain,


@ -52,15 +52,20 @@ namespace js {
Pool<pool_type, is_channels_last>);
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11)
POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1)
POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1)
POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED(MaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11)
POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12)


@ -51,6 +51,7 @@ namespace js {
REGISTER_RESIZE_KERNEL(domain, 19);
REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kOnnxDomain);
REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kMSInternalNHWCDomain);
REGISTER_RESIZE_KERNEL_DOMAIN(kOnnxDomain);
REGISTER_RESIZE_KERNEL_DOMAIN(kMSInternalNHWCDomain);


@ -7,22 +7,22 @@
#include "core/common/common.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/common.h"
#include "core/providers/cpu/nn/pool_attributes.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/xnnpack/detail/utils.h"
// each operator provides a helper to check if supported
#include "core/providers/xnnpack/math/gemm.h"
#include "core/providers/xnnpack/math/matmul.h"
#include "core/providers/xnnpack/math/softmax.h"
#include "core/providers/xnnpack/nn/average_pool.h"
#include "core/providers/xnnpack/nn/conv.h"
#include "core/providers/xnnpack/nn/conv_transpose.h"
#include "core/providers/xnnpack/nn/max_pool.h"
#include "core/providers/xnnpack/math/gemm.h"
#include "core/providers/xnnpack/math/matmul.h"
#include "core/providers/xnnpack/nn/average_pool.h"
#include "core/providers/xnnpack/nn/resize.h"
#include "core/providers/xnnpack/nn/softmax.h"
#include "core/providers/xnnpack/tensor/resize.h"
namespace onnxruntime {
namespace xnnpack {


@ -25,7 +25,7 @@ const char* OpTypeToString(OpComputeType opCtype) {
case op_compute_type_fp16:
return "fp16";
case op_compute_type_qs8_per_channel:
return "qc8";
return "qs8_qc8w";
case op_compute_type_qs8:
return "qs8";
case op_compute_type_qu8:


@ -78,7 +78,7 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra
return supported;
}
Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info) {
Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*enable_caches*/ true) {
const auto& node{Node()};
info.GetAttrOrDefault<float>("alpha", &alpha_, 1.f);
@ -146,14 +146,9 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride,
B_->Data<float>(), // const float* kernel,
bias_Data, // const float* bias,
output_min,
output_max,
output_min, output_max,
flags,
#ifdef XNN_CACHE_ENABLE
&xnn_caches_,
#else
0,
#endif
GetCodeCache(), GetWeightsCache(),
&p);
if (status != xnn_status_success) {
@ -165,20 +160,25 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr,
}
Status Gemm::Compute(OpKernelContext* context) const {
pthreadpool_t t_pool = GetThreadPool();
pthreadpool_t threadpool = GetThreadPool();
const auto* A = context->Input<Tensor>(0);
auto Y = context->Output(0, {M_, N_});
// if input is empty tensor, return as nothing need to be calculated and we've set the shape for the output
if (M_ == 0 || N_ == 0)
if (M_ == 0 || N_ == 0) {
return Status::OK();
}
xnn_status status = xnn_setup_fully_connected_nc_f32(
op0_.get(),
trans_A_ == CblasNoTrans ? M_ : K_, // Number of rows to multiply
A->Data<float>(),
Y->MutableData<float>(),
t_pool);
xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(),
// Number of rows to multiply
trans_A_ == CblasNoTrans ? M_ : K_,
threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
}
status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data<float>(), Y->MutableData<float>());
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
@ -192,7 +192,15 @@ Status Gemm::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 8, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 9, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);


@ -41,14 +41,6 @@ class Gemm : protected GemmBase, public XnnpackKernel {
float alpha_;
float beta_;
#ifdef XNN_CACHE_ENABLE
#if XNN_PLATFORM_JIT
xnn_code_cache code_cache_;
#endif
xnn_caches xnn_caches_ = {0, 0};
xnn_weights_cache weights_cache_;
#endif
};
} // namespace xnnpack


@ -62,7 +62,7 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g
return supported;
}
MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info) {}
MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) {}
Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
@ -99,9 +99,11 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
output_max,
flags,
#ifdef XNN_CACHE_ENABLE
&xnn_caches_,
GetCodeCache(),
GetWeightsCache(),
#else
0,
nullptr,
nullptr,
#endif
&p);
@ -116,7 +118,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
Status MatMul::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
pthreadpool_t t_pool = GetThreadPool();
pthreadpool_t threadpool = GetThreadPool();
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
Tensor* y = ctx->Output(0, helper.OutputShape());
@ -126,13 +128,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
auto* y_data = y->MutableData<float>();
xnn_status status = xnn_setup_fully_connected_nc_f32(
op0_.get(),
a->Shape()[0],
a->Data<float>(),
y_data,
t_pool);
xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
}
status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
}
@ -144,7 +145,11 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 8, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MatMul);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 9, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MatMul);


@ -32,14 +32,6 @@ class MatMul : public XnnpackKernel {
AllocatorPtr myAlloc;
XnnpackOperator op0_ = nullptr;
#ifdef XNN_CACHE_ENABLE
#if XNN_PLATFORM_JIT
xnn_code_cache code_cache_;
#endif
xnn_caches xnn_caches_ = {0, 0};
xnn_weights_cache weights_cache_;
#endif
};
} // namespace xnnpack


@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/xnnpack/nn/softmax.h"
#include "core/providers/xnnpack/math/softmax.h"
#include <utility>
@ -25,6 +25,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph
output_type != TensorTypeUint8) {
break;
}
// to ensure its output scale and zp are 1/256 and 0, otherwise xnnpack EP has to do extra requantization
// idealy, QlinearSoftmax or QDQSoftmax will keep this output scale and zp, but we have to handle some
// qdq models converted from other framework
@ -33,6 +34,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph
if (fabs(q_scale.DataAsSpan<float>()[0] - 1.0f / 256.0f) > 0.0001f) {
break;
}
if (zero_tensor) {
Initializer q_zp(*zero_tensor, node_unit.ModelPath());
if (q_zp.DataAsSpan<uint8_t>()[0] != 0) {
@ -57,6 +59,7 @@ bool Softmax::IsOnnxNodeSupported(const NodeUnit& node_unit,
IsQuantSoftmaxSupported(node_unit, graph) == false) {
return false;
}
// use do {} while(false) so it's easier to set a breakpoint on the return
do {
// SoftMax has 1 input.
@ -133,6 +136,7 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} {
ORT_ENFORCE(status.IsOK(), "opset must be existed in attributes of QlinearSoftmax");
opset_ = gsl::narrow_cast<int>(opset);
}
int64_t axis = -1;
Status status = info.GetAttr<int64_t>("axis", &axis);
// our op checker function has ensured that axis must be the last dim
@ -162,23 +166,22 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} {
if (op_type_ == OpComputeType::op_compute_type_qu8) {
// the order of input tensor, x,x_scale, x_zp, y_scale, y_zp
OpQuantParam quant_param = ParseQuantParamForOp(info, x_dtype, 1);
xstatus = xnn_create_softmax_nc_qu8(
channels,
channels,
channels,
quant_param[0].first[0], // x_scale
quant_param[1].second, // y_zp
quant_param[1].first[0], // y_scale
0, // flags,
&p);
xstatus = xnn_create_softmax_nc_qu8(channels,
channels,
channels,
quant_param[0].first[0], // x_scale
quant_param[1].second, // y_zp
quant_param[1].first[0], // y_scale
0, // flags,
&p);
} else if (op_type_ == OpComputeType::op_compute_type_fp32) {
xstatus = xnn_create_softmax_nc_f32(
channels,
channels,
channels,
0, // flags,
&p);
xstatus = xnn_create_softmax_nc_f32(channels,
channels,
channels,
0, // flags,
&p);
}
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_",
OpTypeToString(op_type_), " failed. Status:", xstatus);
op0_.reset(p);
@ -194,39 +197,48 @@ Status Softmax::Compute(OpKernelContext* ctx) const {
if (X_shape.Size() == 0) {
return Status::OK();
}
pthreadpool_t t_pool = GetThreadPool();
pthreadpool_t threadpool = GetThreadPool();
const size_t N = X_shape.SizeToDimension(axis_);
// const size_t D = X_shape.SizeFromDimension(axis_); // the step D is 1
xnn_status status = xnn_status_invalid_state;
if (op_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_softmax_nc_qu8(
op0_.get(),
N,
X->Data<uint8_t>(),
Y->MutableData<uint8_t>(),
t_pool);
} else {
status = xnn_setup_softmax_nc_f32(
op0_.get(),
N,
X->Data<float>(),
Y->MutableData<float>(),
t_pool);
}
auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? xnn_reshape_softmax_nc_qu8
: xnn_reshape_softmax_nc_f32;
status = reshape_fn(op0_.get(), N, threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_",
OpTypeToString(op_type_), " returned ", status);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_softmax_nc_", OpTypeToString(op_type_),
" returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
if (op_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_softmax_nc_qu8(op0_.get(), X->Data<uint8_t>(), Y->MutableData<uint8_t>());
} else {
status = xnn_setup_softmax_nc_f32(op0_.get(), X->Data<float>(), Y->MutableData<float>());
}
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_", OpTypeToString(op_type_),
" returned ", status);
}
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Softmax);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Softmax);
ONNX_OPERATOR_KERNEL_EX(Softmax, kOnnxDomain, 13, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Softmax);
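
For reference, the updated Softmax code above follows XNNPACK's new operator lifecycle: create once, then per invocation call reshape (shape/size information plus threadpool) and setup (data pointers only) before running. A minimal standalone sketch of that flow, mirroring the calls in the diff and using a null threadpool for brevity:

```cpp
#include <cstdio>
#include <vector>
#include "xnnpack.h"

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  const size_t channels = 8;
  const size_t batch = 4;
  std::vector<float> x(batch * channels, 1.0f), y(batch * channels, 0.0f);

  // create: channels, input stride, output stride, flags
  xnn_operator_t op = nullptr;
  if (xnn_create_softmax_nc_f32(channels, channels, channels, /*flags=*/0, &op) != xnn_status_success) return 1;

  // reshape: per-invocation shape information (batch size here)
  if (xnn_reshape_softmax_nc_f32(op, batch, /*threadpool=*/nullptr) != xnn_status_success) return 1;
  // setup: binds only the input/output pointers
  if (xnn_setup_softmax_nc_f32(op, x.data(), y.data()) != xnn_status_success) return 1;
  // run
  if (xnn_run_operator(op, /*threadpool=*/nullptr) != xnn_status_success) return 1;

  std::printf("y[0] = %f\n", static_cast<double>(y[0]));
  xnn_delete_operator(op);
  return 0;
}
```

The kernels in this PR do the same thing, but with the EP's private pthreadpool and ORT status handling.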

View file

@ -2,10 +2,13 @@
// Licensed under the MIT License.
#include "core/providers/xnnpack/nn/average_pool.h"
#include <memory>
#include "core/common/status.h"
#include "core/graph/graph.h"
#include "core/providers/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/providers/xnnpack/detail/utils.h"
namespace onnxruntime {
@ -90,6 +93,10 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit,
const auto& inputs = node_unit.Inputs();
// use do {} while(false) so it's easier to set a breakpoint on the return
do {
if (node_unit.SinceVersion() < 7) {
break;
}
// AveragePool has 1 input.
const auto& x_arg = inputs[0].node_arg;
@ -141,6 +148,11 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit,
break;
}
// need dilations to all be 1
if (!pool_attrs.default_dilations) {
break;
}
supported = true;
} while (false);
@ -221,24 +233,47 @@ Status AveragePool::Compute(OpKernelContext* context) const {
return Status::OK();
}
pthreadpool_t t_pool = GetThreadPool();
xnn_status status = xnn_status_invalid_state;
pthreadpool_t threadpool = GetThreadPool();
// set up the allocator and automated deallocation for the workspace buffer
size_t workspace_size = 0;
size_t workspace_alignment = 0;
xnn_allocator* allocator = GetStoredAllocator().second;
auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);
auto reshape_fn = (avgpool_type_ == OpComputeType::op_compute_type_fp32)
? xnn_reshape_average_pooling2d_nhwc_f32
: xnn_reshape_average_pooling2d_nhwc_qu8;
auto status = reshape_fn(op0_.get(), N, H, W,
&workspace_size, &workspace_alignment,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_),
" returned ", status);
}
workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));
if (avgpool_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), N, H, W,
X.Data<float>(), Y.MutableData<float>(),
t_pool /*threadpool */);
status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(),
X.Data<float>(), Y.MutableData<float>());
} else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), N, H, W,
X.Data<uint8_t>(), Y.MutableData<uint8_t>(),
t_pool /*threadpool */);
status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(),
X.Data<uint8_t>(), Y.MutableData<uint8_t>());
}
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_",
OpTypeToString(avgpool_type_), " returned ", status);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_),
" returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
@ -246,8 +281,26 @@ Status AveragePool::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
AveragePool, kMSInternalNHWCDomain, 7, 9,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
AveragePool);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
AveragePool, kMSInternalNHWCDomain, 10, 10,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
AveragePool);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
AveragePool, kMSInternalNHWCDomain, 11, 18,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
AveragePool);
ONNX_OPERATOR_KERNEL_EX(
AveragePool, kMSInternalNHWCDomain, 11,
AveragePool, kMSInternalNHWCDomain, 19,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
AveragePool);
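
The workspace handling added above (query the size in reshape, allocate through the XNNPACK allocator ORT registered, release via aligned_deallocate) repeats in Conv and Resize below. A hypothetical helper capturing that pattern — AllocateWorkspace is illustrative only and not part of this change; XNN_ALLOCATION_ALIGNMENT comes from xnnpack_init.h further down:

```cpp
#include <functional>
#include <memory>
#include "xnnpack.h"  // struct xnn_allocator

using WorkspacePtr = std::unique_ptr<void, std::function<void(void*)>>;

// Allocate a reshape-reported workspace through the stored XNNPACK allocator and
// release it with the matching aligned_deallocate when the pointer goes out of scope.
WorkspacePtr AllocateWorkspace(xnn_allocator* allocator, size_t size, size_t alignment) {
  void* ptr = size > 0 ? allocator->aligned_allocate(allocator->context, alignment, size) : nullptr;
  return WorkspacePtr(ptr, [allocator](void* p) {
    if (p != nullptr) allocator->aligned_deallocate(allocator->context, p);
  });
}
```

In the kernels above this would be called as `auto workspace = AllocateWorkspace(GetStoredAllocator().second, workspace_size, XNN_ALLOCATION_ALIGNMENT);` once the reshape call has reported the required size.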

View file

@ -3,12 +3,13 @@
#include "conv.h"
#include "core/common/gsl.h"
#include "core/common/inlined_containers_fwd.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/common/gsl.h"
namespace onnxruntime {
namespace xnnpack {
@ -64,21 +65,48 @@ Status Conv::Compute(OpKernelContext* context) const {
if (Y->Shape().Size() == 0) {
return Status::OK();
}
pthreadpool_t t_pool = GetThreadPool();
xnn_status status = xnn_status_invalid_state;
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), N, H, W, X.Data<float>(), Y->MutableData<float>(),
t_pool /*threadpool*/);
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), N, H, W, X.Data<int8_t>(), Y->MutableData<int8_t>(),
t_pool /*threadpool*/);
pthreadpool_t threadpool = GetThreadPool();
// set up the allocator and automated deallocation for the workspace buffer
size_t workspace_size = 0;
size_t workspace_alignment = 0;
xnn_allocator* allocator = GetStoredAllocator().second;
auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);
auto reshape_fn = xnn_reshape_convolution2d_nhwc_f32;
if (conv_type_ == OpComputeType::op_compute_type_qs8) {
reshape_fn = xnn_reshape_convolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), N, H, W, X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
t_pool /*threadpool*/);
reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
status = xnn_setup_convolution2d_nhwc_qc8(op0_.get(), N, H, W, X.Data<int8_t>(), Y->MutableData<int8_t>(),
t_pool /*threadpool*/);
reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
}
auto status = reshape_fn(op0_.get(), N, H, W,
&workspace_size, &workspace_alignment,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
"returned ", status);
}
workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data<float>(),
Y->MutableData<float>());
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data<int8_t>(),
Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), workspace.get(), X.Data<uint8_t>(),
Y->MutableData<uint8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data<int8_t>(),
Y->MutableData<int8_t>());
}
if (status != xnn_status_success) {
@ -86,7 +114,7 @@ Status Conv::Compute(OpKernelContext* context) const {
OpTypeToString(conv_type_), "returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
@ -94,6 +122,10 @@ Status Conv::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv);
ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv);

View file

@ -23,7 +23,8 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
const std::optional<std::pair<float, float>>& clip_min_max,
const Tensor& Weight, const Tensor* Bias,
XnnpackOperator& op_uptr,
xnn_caches_t caches_t,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
const OpQuantParam& quant_param,
OpComputeType conv_type,
bool is_transpose = false) {
@ -75,7 +76,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
C, M, // input channel stride, output channel stride
Weight.Data<float>(), B_data,
foutput_min, foutput_max, flags,
caches_t,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_qs8) {
const float output_scale = quant_param[2].first[0];
@ -99,7 +100,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
quant_param[2].second, quant_param[2].first[0],
output_min, output_max,
flags,
caches_t,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) {
auto* B_data = Bias ? Bias->Data<int32_t>() : nullptr;
@ -107,7 +108,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
const int8_t output_zero_point = quant_param[2].second;
const int8_t output_min = xnn_u8s8_quantize<int8_t>(foutput_min, output_scale, output_zero_point);
const int8_t output_max = xnn_u8s8_quantize<int8_t>(foutput_max, output_scale, output_zero_point);
status = xnn_create_convolution2d_nhwc_qc8(
status = xnn_create_convolution2d_nhwc_qs8_qc8w(
input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
kernel_height, kernel_width,
subsampling_height, subsampling_width,
@ -123,7 +124,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
quant_param[2].second, quant_param[2].first[0],
output_min, output_max,
flags,
caches_t,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_qu8) {
const auto* B_data = Bias ? Bias->Data<int32_t>() : nullptr;
@ -148,15 +149,17 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
quant_param[2].second, quant_param[2].first[0],
output_min, output_max,
flags,
caches_t,
code_cache, weights_cache,
&p);
}
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Failed to create xnnpack kernel. xnn_create_",
is_transpose ? "deconvolution2d" : "convolution2d", "_nhwc_",
OpTypeToString(conv_type), " returned ", status);
}
op_uptr.reset(p);
return Status::OK();
}
@ -296,6 +299,11 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&
const onnxruntime::Node& node = node_unit.GetNode();
// use do {} while(false) so it's easier to set a breakpoint on the return
do {
// Internal NHWC domain starts at opset 11
if (node_unit.SinceVersion() < 11) {
break;
}
// Conv has at least 2 inputs.
const auto& inputs = node_unit.Inputs();
const auto& x_arg = inputs[0].node_arg;
@ -367,7 +375,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&
}
ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
: XnnpackKernel(info),
: XnnpackKernel(info, /*enable_caches*/ true),
conv_attrs_(info),
conv_transpose_attrs_(info),
convbase_attrs_ref_(is_transpose ? conv_transpose_attrs_ : conv_attrs_),
@ -383,16 +391,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
}
}
}
// xnnpack cache_code, unfortunately these definitions are only available in xnnpack/cache.h,
#ifdef XNN_CACHE_ENABLE
#if XNN_PLATFORM_JIT
xnn_init_code_cache(&code_cache_);
xnn_caches_.code_cache = &code_cache_;
#endif
// TODO(Jicwen) enable weight-cache and code-cache
xnn_init_weights_cache(&weights_cache_);
xnn_caches_.weights_cache = &weights_cache_;
#endif
const auto& node{Node()};
const auto& input_defs = node.InputDefs();
const NodeArg& X = *input_defs[0];
@ -477,11 +476,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
Status ConvBase::CreateKernel() {
auto ret = CreateXnnpackKernel(&convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_,
B_, op0_,
#ifdef XNN_CACHE_ENABLE
&xnn_caches_,
#else
0,
#endif
GetCodeCache(), GetWeightsCache(),
quant_param_, conv_type_, is_transpose_);
return ret;
}

View file

@ -39,14 +39,6 @@ class ConvBase : public XnnpackKernel {
std::optional<std::pair<float, float>> clip_min_max_;
XnnpackOperator op0_ = nullptr;
// we can't have the definition here because we can't import xnnpack/cache.h
#ifdef XNN_CACHE_ENABLE
#if XNN_PLATFORM_JIT
xnn_code_cache code_cache_;
#endif
xnn_caches xnn_caches_ = {0, 0};
xnn_weights_cache weights_cache_;
#endif
OpQuantParam quant_param_;
OpComputeType conv_type_ = OpComputeType::op_compute_type_invalid;
};

View file

@ -81,29 +81,34 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
if (Y->Shape().Size() == 0) {
return Status::OK();
}
pthreadpool_t t_pool = GetThreadPool();
pthreadpool_t threadpool = GetThreadPool();
auto output_pad_0 = gsl::narrow_cast<uint32_t>(conv_transpose_attrs_.output_padding[0]);
auto output_pad_1 = gsl::narrow_cast<uint32_t>(conv_transpose_attrs_.output_padding[1]);
xnn_status status = xnn_status_invalid_state;
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_deconvolution2d_nhwc_f32(
op0_.get(), N, H, W,
output_pad_0,
output_pad_1, X.Data<float>(), Y->MutableData<float>(),
t_pool /*threadpool*/);
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_deconvolution2d_nhwc_qs8(
op0_.get(), N, H, W,
output_pad_0,
output_pad_1, X.Data<int8_t>(), Y->MutableData<int8_t>(),
t_pool /*threadpool*/);
auto reshape_fn = xnn_reshape_deconvolution2d_nhwc_f32;
if (conv_type_ == OpComputeType::op_compute_type_qs8) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_deconvolution2d_nhwc_qu8(
op0_.get(), N, H, W,
output_pad_0,
output_pad_1, X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
t_pool /*threadpool*/);
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8;
}
status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_deconvolution2d_nhwc_",
OpTypeToString(conv_type_), " returned ", status);
}
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_deconvolution2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
}
if (status != xnn_status_success) {
@ -111,7 +116,7 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
OpTypeToString(conv_type_), " returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}

View file

@ -41,6 +41,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
const onnxruntime::Node& node = node_unit.GetNode();
// use do {} while(false) so it's easier to set a breakpoint on the return
do {
if (node_unit.SinceVersion() < 8) {
break;
}
// MaxPool has 1 input.
auto input_defs = node.InputDefs();
const auto& x_arg = *input_defs[0];
@ -220,20 +224,29 @@ Status MaxPool::Compute(OpKernelContext* context) const {
return Status::OK();
}
pthreadpool_t t_pool = GetThreadPool();
xnn_status status = xnn_status_invalid_state;
pthreadpool_t threadpool = GetThreadPool();
auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32;
if (maxpool_type_ == OpComputeType::op_compute_type_qu8)
reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8;
else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8;
}
auto status = reshape_fn(op0_.get(), N, H, W,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_max_pooling2d_nhwc_",
OpTypeToString(maxpool_type_), " returned ", status);
}
if (maxpool_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), N, H, W,
X.Data<float>(), Y->MutableData<float>(),
t_pool /*threadpool */);
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
} else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), N, H, W,
X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
t_pool /*threadpool */);
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
} else {
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), N, H, W,
X.Data<int8_t>(), Y->MutableData<int8_t>(),
t_pool /*threadpool */);
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
}
if (status != xnn_status_success) {
@ -241,7 +254,7 @@ Status MaxPool::Compute(OpKernelContext* context) const {
OpTypeToString(maxpool_type_), " returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
@ -249,12 +262,24 @@ Status MaxPool::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
MaxPool);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
MaxPool);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
MaxPool);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
MaxPool);
ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/xnnpack/nn/resize.h"
#include "core/providers/xnnpack/tensor/resize.h"
#include <algorithm>
#include <utility>
@ -10,6 +10,7 @@
#include "core/common/inlined_containers_fwd.h"
#include "core/framework/op_kernel.h"
#include "core/optimizer/initializer.h"
#include "core/providers/xnnpack/xnnpack_init.h"
namespace onnxruntime {
namespace xnnpack {
@ -18,26 +19,67 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
const GraphViewer& graph_viewer) {
bool supported = false;
do {
if (node_unit.SinceVersion() < 10) {
break;
}
// Resize has 1-4 input.
const auto& inputs = node_unit.Inputs();
const auto& x_arg = inputs[0].node_arg;
const auto* x_type = x_arg.TypeAsProto();
if (x_type == nullptr ||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
if (x_type == nullptr || (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
break;
}
const auto* x_shape = x_arg.Shape();
//'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) or
// 4-D input with outermost and innermost scales as 1 (NHWC)
// but we just support 4-d tensor for now, and the channel must be known.
// 'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) can be supported.
// we only support 4-D tensors for now, and the channel must be known.
// we assume the input is NCHW for this check.
if (!x_shape || x_shape->dim_size() != 4 || x_shape->dim(1).dim_value() <= 0) {
break;
}
// validate it is in fact NCHW
//
// opset 10 had `scales` as input 1 and no sizes. later opsets added roi as input 1 followed by scales and sizes.
auto opset_version = node_unit.SinceVersion();
size_t scale_idx = opset_version == 10 ? 1 : 2;
size_t size_idx = 3;
// onnx shape inferencing validates that one and not both of sizes and scales are provided
const auto* scale_tensor = inputs.size() >= scale_idx + 1
? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true)
: nullptr;
const auto* size_tensor = opset_version > 10 && inputs.size() >= size_idx + 1
? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true)
: nullptr;
// if both scales and sizes are nullptr the one that was provided was not a constant initializer
if (!scale_tensor && !size_tensor) {
break;
}
// check the scale for the second dim is 1 or the size of the second dim matches the input shape.
// if not, it is not the C dim as a Resize will not change the number of channels.
InlinedVector<float> scale(4, 1.0F);
if (scale_tensor) {
const Initializer scale_val(*scale_tensor, node_unit.ModelPath());
if (scale_val.DataAsSpan<float>()[1] != 1.0F) {
break;
}
}
if (size_tensor) {
const Initializer size_val(*size_tensor, node_unit.ModelPath());
if (size_val.DataAsSpan<int64_t>()[1] != x_shape->dim(1).dim_value()) {
break;
}
}
const auto* output_shape = node_unit.Outputs()[0].node_arg.Shape();
bool length_resized_compatible_pytorch_half_pixel = true;
// when length_resized > 1, there is no difference between pytorch_half_pixel and half_pixel
@ -48,18 +90,11 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
// if coordinate_transformation_mode is "pytorch_half_pixel",
// x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0
//
if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(1).dim_value() <= 1) {
if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(3).dim_value() <= 1) {
// we don't know the output H or W so we don't know if it will be compatible
length_resized_compatible_pytorch_half_pixel = false;
}
// Refer to onnxruntime/core/providers/cpu/tensor/upsamplebase.h,
size_t scale_idx = 2;
size_t size_idx = 3;
auto opset_version = node_unit.SinceVersion();
if (opset_version == 10) {
scale_idx = 1;
}
ProtoHelperNodeContext nc(node_unit.GetNode());
OpNodeProtoHelper info(&nc);
@ -78,6 +113,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
std::vector<int64_t> axes;
if (info.GetAttrs<int64_t>("axes", axes).IsOK() && axes.size() > 0) {
// TODO: We should be able to handle this if required
break;
}
@ -95,9 +131,10 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
// Coordinate transformation mode attr was introduced in version 11.
// before that asymmetric mode was the only available transformation mode
std::string coordinate_transform_mode_name =
opset_version > 10
? info.GetAttrOrDefault<std::string>("coordinate_transformation_mode", "half_pixel")
: "asymmetric";
opset_version > 10 ? info.GetAttrOrDefault<std::string>("coordinate_transformation_mode", "half_pixel")
: "asymmetric";
// TODO: Opset 19 added half_pixel_symmetric. Need to see if that can be supported.
if (coordinate_transform_mode_name != "asymmetric" &&
coordinate_transform_mode_name != "half_pixel" &&
@ -106,59 +143,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
break;
}
auto exclude_outside = info.GetAttrOrDefault<int64_t>("exclude_outside", 0) == 0 ? false : true;
if (exclude_outside) {
break;
}
// roi only takes effect when coordinate_transformation_mode is "tf_crop_and_resize"
// size or scales shouldnt't be provided in the same time but should at least be provided one of them
const auto* scale_tensor = inputs.size() >= scale_idx + 1
? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true)
: nullptr;
const auto* size_tensor = inputs.size() >= size_idx + 1
? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true)
: nullptr;
bool has_size = false;
bool has_scale = false;
InlinedVector<float> scale(4, 1.0F);
if (scale_tensor) {
const Initializer scale_val(*scale_tensor, node_unit.ModelPath());
auto scale_span = scale_val.DataAsSpan<float>();
if (scale_span.size() == 4) {
has_scale = true;
std::copy(scale_span.begin(), scale_span.end(), scale.begin());
}
}
if (size_tensor) {
auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*x_shape);
const Initializer size_val(*size_tensor, node_unit.ModelPath());
auto size_span = size_val.DataAsSpan<int64_t>();
if (size_span.size() == 4) {
has_size = true;
scale = {size_span[0] / static_cast<float>(input_shape[0]),
size_span[1] / static_cast<float>(input_shape[1]),
size_span[2] / static_cast<float>(input_shape[2]),
size_span[3] / static_cast<float>(input_shape[3])};
}
}
if ((has_size && has_scale) || (!has_size && !has_scale)) {
break;
}
if (scale[0] != 1.0F || (scale[1] != 1.0F && scale[3] != 1.0F)) {
break;
}
// only support xnn_create_resize_bilinear2d_nchw_f32
const bool is_NHWC = scale[3] == 1.0F;
if (!is_NHWC && (x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
if (info.GetAttrOrDefault<int64_t>("exclude_outside", 0) != 0) {
break;
}
@ -210,8 +195,7 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf
}
}
is_NHWC_ = scales_[3] == 1.0F;
int64_t channels = x_shape->dim(is_NHWC_ ? 3 : 1).dim_value();
int64_t channels = x_shape->dim(3).dim_value();
uint32_t flags = 0;
ORT_ENFORCE(mode_ == UpsampleMode::LINEAR, "only support bilinear resize");
@ -225,18 +209,16 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf
xnn_status xstatus = xnn_status_invalid_state;
struct xnn_operator* p = nullptr;
if (op_type_ == OpComputeType::op_compute_type_fp32) {
auto create_func = is_NHWC_ ? xnn_create_resize_bilinear2d_nhwc_f32 : xnn_create_resize_bilinear2d_nchw_f32;
xstatus = create_func(
channels, channels, channels, flags, &p);
xstatus = xnn_create_resize_bilinear2d_nhwc_f32(channels, channels, channels, flags, &p);
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
xstatus = xnn_create_resize_bilinear2d_nhwc_u8(
channels, channels, channels, flags, &p);
xstatus = xnn_create_resize_bilinear2d_nhwc_u8(channels, channels, channels, flags, &p);
} else {
xstatus = xnn_create_resize_bilinear2d_nhwc_s8(
channels, channels, channels, flags, &p);
xstatus = xnn_create_resize_bilinear2d_nhwc_s8(channels, channels, channels, flags, &p);
}
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_",
OpTypeToString(op_type_), " failed. Status:", xstatus);
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:",
xstatus);
op0_.reset(p);
}
@ -245,48 +227,56 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input,
const TensorShapeVector& output_dims) const {
const auto& X_shape = input->Shape();
auto N = X_shape[0];
auto H = is_NHWC_ ? X_shape[1] : X_shape[2];
auto W = is_NHWC_ ? X_shape[2] : X_shape[3];
auto H = X_shape[1];
auto W = X_shape[2];
Tensor* output = ctx->Output(0, TensorShape(output_dims));
pthreadpool_t t_pool = GetThreadPool();
xnn_status status = xnn_status_invalid_state;
if (op_type_ == OpComputeType::op_compute_type_fp32) {
auto oH = is_NHWC_ ? output_dims[1] : output_dims[2];
auto oW = is_NHWC_ ? output_dims[2] : output_dims[3];
auto setup_func = is_NHWC_ ? xnn_setup_resize_bilinear2d_nhwc_f32 : xnn_setup_resize_bilinear2d_nchw_f32;
status = setup_func(
op0_.get(),
N,
H, W, oH, oW,
input->Data<float>(),
output->MutableData<float>(),
t_pool);
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_resize_bilinear2d_nhwc_u8(
op0_.get(),
N,
H, W, output_dims[1], output_dims[2],
input->Data<uint8_t>(),
output->MutableData<uint8_t>(),
t_pool);
} else {
status = xnn_setup_resize_bilinear2d_nhwc_s8(
op0_.get(),
N,
H, W, output_dims[1], output_dims[2],
input->Data<int8_t>(),
output->MutableData<int8_t>(),
t_pool);
pthreadpool_t threadpool = GetThreadPool();
// set up the allocator and automated deallocation for the workspace buffer
size_t workspace_size = 0;
size_t workspace_alignment = 0;
xnn_allocator* allocator = GetStoredAllocator().second;
auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);
auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32;
if (op_type_ == OpComputeType::op_compute_type_qu8) {
reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8;
} else if (op_type_ == OpComputeType::op_compute_type_qs8) {
reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8;
}
auto status = reshape_fn(op0_.get(), N, H, W, output_dims[1], output_dims[2],
&workspace_size, &workspace_alignment, threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_),
" returned ", status);
}
workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));
if (op_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data<float>(),
output->MutableData<float>());
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data<uint8_t>(),
output->MutableData<uint8_t>());
} else {
status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data<int8_t>(),
output->MutableData<int8_t>());
}
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_resize_bilinear2d_nhwc_",
OpTypeToString(op_type_), " returned ", status);
}
status = xnn_run_operator(op0_.get(), t_pool);
status = xnn_run_operator(op0_.get(), threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
return Status::OK();
}
@ -315,29 +305,29 @@ Status Resize::Compute(OpKernelContext* ctx) const {
return ComputeInternal(ctx, X, output_shape);
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 10, 10, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
Resize);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 11, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
Resize);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 13, 17, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
Resize);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 18, 18, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
Resize);
ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider,
ONNX_OPERATOR_KERNEL_EX(Resize, kMSInternalNHWCDomain, 19, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),

View file

@ -31,7 +31,6 @@ class Resize : public UpsampleBase, public XnnpackKernel {
const TensorShapeVector& output_dims) const;
private:
bool is_NHWC_;
XnnpackOperator op0_;
TensorShapeVector output_dims_;
OpComputeType op_type_ = OpComputeType::op_compute_type_invalid;

View file

@ -27,88 +27,117 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
return info;
}
#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, End, Op)>
#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op, Domain) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)>
#define KERNEL_CREATE_INFO(Start, Op) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, Op)>
#define KERNEL_CREATE_INFO(Start, Op, Domain) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)>
#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, type, Op)>
#define KERNEL_CREATE_INFO_TYPED(Start, Type, Op, Domain) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>
// Layout sensitive operators in NHWC domain
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider,
kDynamicDomainByCreate, 1, QLinearSoftmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
// ONNX operators
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 8, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 10, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax);
// Internal domain
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list becoming empty after ops-reducing
KERNEL_CREATE_INFO(11, Conv),
KERNEL_CREATE_INFO(11, ConvTranspose),
KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose),
KERNEL_CREATE_INFO(1, QLinearConvTranspose),
KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool),
KERNEL_CREATE_INFO(12, MaxPool),
KERNEL_CREATE_INFO(11, AveragePool),
// layout sensitive. nodes will be moved to kMSInternalNHWCDomain by layout transformation
KERNEL_CREATE_INFO_VERSIONED(7, 9, AveragePool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(10, 10, AveragePool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(11, 18, AveragePool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(19, AveragePool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(1, 10, Conv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(11, Conv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(11, ConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(8, 9, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(10, 10, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(12, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(1, QLinearConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(10, 10, Resize, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(11, 12, Resize, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(13, 17, Resize, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED(18, 18, Resize, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(19, Resize, kMSInternalNHWCDomain),
// layout insensitive, use ONNX-domain directly
BuildKernelCreateInfo<
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul)>,
KERNEL_CREATE_INFO_VERSIONED(1, 10, Softmax, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(11, 12, Softmax, kOnnxDomain),
KERNEL_CREATE_INFO(13, Softmax, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(7, 8, Gemm, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(9, 10, Gemm, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(11, 12, Gemm, kOnnxDomain),
KERNEL_CREATE_INFO(13, Gemm, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(1, 8, MatMul, kOnnxDomain),
KERNEL_CREATE_INFO_VERSIONED(9, 12, MatMul, kOnnxDomain),
KERNEL_CREATE_INFO(13, MatMul, kOnnxDomain),
// quantization op
KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv),
KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv),
KERNEL_CREATE_INFO(1, QLinearAveragePool),
BuildKernelCreateInfo<
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax)>,
KERNEL_CREATE_INFO(1, QLinearAveragePool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),
};
for (auto& function_table_entry : function_table) {

View file

@ -26,13 +26,15 @@ void xnn_deallocate(void* context, void* pointer) {
}
void* xnn_aligned_allocate(void* context, size_t alignment, size_t size) {
if (size == 0)
return nullptr;
#if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__)
ORT_ENFORCE(alignment <= 2 * sizeof(void*));
return xnn_allocate(context, size);
#else
void* ptr = xnn_allocate(context, size);
ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0,
" xnnpack wants to allocate a space with ", alignment, "bytes aligned. But it's not satisfied");
ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0, "xnnpack allocation was not aligned to ", alignment, " bytes.");
// if ptr is not aligned, we have to find a way to return an aligned ptr and store the original ptr
return ptr;
#endif
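
The comment above about handling a misaligned pointer is left unimplemented here. If it were ever needed, one conventional approach (illustrative only, not part of this change) is to over-allocate, align manually, and stash the original pointer just below the aligned block:

```cpp
#include <cstdint>
#include <cstdlib>

// alignment must be a power of two
void* aligned_allocate_fallback(size_t alignment, size_t size) {
  // over-allocate so there is room to align and to store the original pointer
  void* raw = std::malloc(size + alignment + sizeof(void*));
  if (raw == nullptr) return nullptr;
  uintptr_t base = reinterpret_cast<uintptr_t>(raw) + sizeof(void*);
  uintptr_t aligned = (base + alignment - 1) & ~(static_cast<uintptr_t>(alignment) - 1);
  reinterpret_cast<void**>(aligned)[-1] = raw;  // remember the original allocation
  return reinterpret_cast<void*>(aligned);
}

void aligned_deallocate_fallback(void* ptr) {
  if (ptr != nullptr) std::free(reinterpret_cast<void**>(ptr)[-1]);
}
```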

View file

@ -5,6 +5,47 @@ struct xnn_allocator;
namespace onnxruntime {
namespace xnnpack {
// copy #define logic from XNNPACK src/xnnpack/common.h to determine workspace alignment
#if defined(__APPLE__)
#include <TargetConditionals.h>
#endif
#if defined(__i386__) || defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(_M_IX86)
#define XNN_ARCH_X86 1
#else
#define XNN_ARCH_X86 0
#endif
#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) && !defined(_M_ARM64EC)
#define XNN_ARCH_X86_64 1
#else
#define XNN_ARCH_X86_64 0
#endif
#if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__)
#define XNN_ARCH_WASM 1
#else
#define XNN_ARCH_WASM 0
#endif
#if defined(__ANDROID__) || (defined(__APPLE__) && TARGET_OS_IPHONE)
#define XNN_PLATFORM_MOBILE 1
#else
#define XNN_PLATFORM_MOBILE 0
#endif
#if XNN_ARCH_WASM
#define XNN_ALLOCATION_ALIGNMENT 4
#elif XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_PLATFORM_MOBILE
#define XNN_ALLOCATION_ALIGNMENT 32
#else
#define XNN_ALLOCATION_ALIGNMENT 64
#endif
#else
#define XNN_ALLOCATION_ALIGNMENT 16
#endif
std::pair<AllocatorPtr&, xnn_allocator*> GetStoredAllocator();
} // namespace xnnpack

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/framework/op_kernel.h"
#include "core/providers/xnnpack/xnnpack_execution_provider.h"
#include "xnnpack.h"
struct pthreadpool;
@ -12,18 +13,59 @@ namespace xnnpack {
class XnnpackKernel : public OpKernel {
public:
explicit XnnpackKernel(const OpKernelInfo& info)
: OpKernel(info),
xnnpack_threadpool_(
static_cast<const XnnpackExecutionProvider*>(info.GetExecutionProvider())
->GetPrivateThreadPool()) {
explicit XnnpackKernel(const OpKernelInfo& info, bool enable_caches = false)
: OpKernel{info},
xnnpack_threadpool_{
static_cast<const XnnpackExecutionProvider*>(info.GetExecutionProvider())->GetPrivateThreadPool()},
caches_{enable_caches} {
}
[[nodiscard]] pthreadpool* GetThreadPool() const {
return xnnpack_threadpool_;
}
// see comment below about enabling code cache
// xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();}
xnn_code_cache_t GetCodeCache() { return nullptr; }
xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); }
private:
pthreadpool* xnnpack_threadpool_;
// Helper class to wrap usage of the XNNPACK weights and code caches.
// NOTE: Currently creating/freeing the code cache is not exposed via the public xnnpack.h header so usage is
// commented out. If we need to use it, we'll need to add the 'src' directory of XNNPACK to the include path
// and #include "xnnpack/cache.h"
struct Caches {
Caches(bool enable)
: // auto_code_cache(nullptr, xnn_release_code_cache),
auto_weights_cache(nullptr, xnn_delete_weights_cache) {
if (enable) {
#ifdef XNN_CACHE_ENABLE
xnn_status status = xnn_status_success;
#if XNN_PLATFORM_JIT
// status = xnn_init_code_cache(&code_cache_);
// ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");)
// auto_code_cache.reset(&code_cache_);
#endif
// status = xnn_init_weights_cache(&weights_cache_);
xnn_weights_cache_t weights_cache = nullptr;
status = xnn_create_weights_cache(&weights_cache, 0);
ORT_ENFORCE(status == xnn_status_success, "Failed to create XNNPACK weights cache");
auto_weights_cache.reset(weights_cache);
#endif
}
}
// std::unique_ptr<xnn_code_cache, decltype(&xnn_release_code_cache)> auto_code_cache;
std::unique_ptr<xnn_weights_cache, decltype(&xnn_delete_weights_cache)> auto_weights_cache;
// private:
// #if defined(XNN_CACHE_ENABLE) && XNN_PLATFORM_JIT
// xnn_code_cache code_cache_;
// #endif
};
Caches caches_;
};
} // namespace xnnpack
} // namespace onnxruntime
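
The Caches member above is essentially an RAII wrapper over the public weights-cache API. A standalone sketch of the same idea, mirroring the calls used in the struct; it is only meaningful when XNNPACK is built with cache support enabled:

```cpp
#include <memory>
#include "xnnpack.h"

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  // create the weights cache via the public API (same call the Caches struct uses)
  xnn_weights_cache_t weights_cache = nullptr;
  if (xnn_create_weights_cache(&weights_cache, 0) != xnn_status_success) return 1;

  // own it with the matching deleter so it is released automatically
  std::unique_ptr<xnn_weights_cache, decltype(&xnn_delete_weights_cache)>
      auto_weights_cache(weights_cache, xnn_delete_weights_cache);

  // auto_weights_cache.get() is what GetWeightsCache() hands to the
  // xnn_create_*_nhwc_* calls so packed weights can be shared between operators.
  return 0;
}
```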

View file

@ -399,6 +399,8 @@ bool SetEpsForAllNodes(Graph& graph,
const std::vector<std::unique_ptr<IExecutionProvider>>& execution_providers,
const std::vector<std::shared_ptr<CustomRegistry>>* custom_registries) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelRegistry::TypeConstraintMap type_constraint_map{};
for (auto& node : graph.Nodes()) {
if (node.OpType() == kConstant)
continue;
@ -426,13 +428,28 @@ bool SetEpsForAllNodes(Graph& graph,
break;
}
// check the internal NHWC domain if EP requests NHWC as it may only have a kernel registered in that domain
if (ep->GetPreferredLayout() == DataLayout::NHWC) {
const KernelCreateInfo* kci = nullptr;
auto status = ep->GetKernelRegistry()->TryFindKernel(ep->Type(),
std::string_view(node.OpType()),
std::string_view(kMSInternalNHWCDomain),
node.SinceVersion(),
type_constraint_map,
&kci);
if (status.IsOK() && kci != nullptr) {
found = true;
break;
}
}
// Check the EP has an impl for the node from custom_registries
if (custom_registries != nullptr &&
std::any_of(custom_registries->cbegin(), custom_registries->cend(),
[&](auto reg) { return KernelRegistry::HasImplementationOf(
*reg->GetKernelRegistry(),
node, ep->Type(),
kernel_type_str_resolver); })) {
[&](auto reg) {
return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(),
kernel_type_str_resolver);
})) {
found = true;
break;
}
@ -760,7 +777,7 @@ void BaseTester::ExecuteModelForEps(
for (const auto& ep : execution_providers) {
providers.append(ep->Type() + " ");
}
LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << "were unable to run the model.";
LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << " were unable to run the model.";
return;
}

View file

@ -21,7 +21,7 @@ struct ConvOp {
std::unique_ptr<CompareOpTester> get_test() {
RandomValueGenerator random{};
auto test = std::make_unique<CompareOpTester>("Conv", 7);
auto test = std::make_unique<CompareOpTester>("Conv", 11); // internal NHWC domain starts at opset 11
std::vector<T> input_data = random.Uniform<T>(input_dims, 0.0f, 1.0f);
std::vector<int64_t> weight_dims{channels, input_dims[1] / group, kernel_shape[0], kernel_shape[1]};

View file

@ -203,7 +203,8 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
}
TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model.onnx";
// the internal NHWC domain supports opset 11 and later
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx";
SessionOptions so;
// set this if you want to manually inspect the optimized model

Binary data
onnxruntime/test/testdata/squeezenet/model_opset11.onnx vendored Normal file

Binary file not shown.

View file

@ -11,7 +11,7 @@ steps:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
version: 1.0.104
version: 1.0.107
downloadPath: $(Build.BinariesDirectory)/deps
# The private ADO project
@ -22,7 +22,7 @@ steps:
packageType: upack
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
version: 1.0.104
version: 1.0.107
downloadPath: $(Build.BinariesDirectory)/deps
# You can add more ADO accounts at here.