Update XNNPACK to latest version (#18038)
### Description

Update XNNPACK to the latest version.
- adds fp16 kernels and various other improvements
- requires a pthreadpool update as well

Most code updates in the XNNPACK EP adjust to the new XNNPACK API:
- 'setup' is split into 'reshape' and 'setup'
- some ops use a workspace buffer
  - copied workspace allocation from XNNPACK unit test code
- some suffixes changed

Added a wrapper for the XNNPACK caches to the base XNNPACK EP kernel:
- simplifies usage
- XNNPACK split out the code and weights caches, but the code cache isn't currently usable via the public API
  - we could use the internal types if we think it's required for performance reasons; non-trivial though, as we'd need to propagate ifdef values from the XNNPACK build up to the ORT build
  - using XNNPACK internals would also mean we would not be able to support using a pre-built XNNPACK package
    - not an issue currently

Fixed opset registration for the internal NHWC domain:
- it was not being tied to the ONNX version, so nodes inserted by layout transformation had the incorrect opset
- a number of other places needed updating once this issue was fixed

Removed support for NCHW Resize from the XNNPACK EP so it's NHWC only:
- we only supported NCHW for fp32
- supporting NCHW adds complexity in multiple places (XNNPACK EP kernel implementation, layout transformation and transpose optimization)
- it's unclear that complexity provides any benefit; it can be added back if required by a production scenario

### Motivation and Context

We're looking at enabling fp16 support for CoreML and NNAPI. If we do that we need a good fallback story for when the CPU EP is used. The XNNPACK fp16 kernels will hopefully provide that.

NOTE: This PR doesn't add fp16 support to the XNNPACK EP kernels. That can be done as required in separate PRs and should be relatively simple to do.
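The central API change the EP code adapts to is visible in the Gemm and MatMul hunks further down: the old single `xnn_setup_*` call (shape, data pointers and thread pool together) becomes an `xnn_reshape_*` call followed by a slimmer `xnn_setup_*` call. A minimal sketch of the new call sequence, using only the fully-connected calls that appear in this diff; the helper function name and error handling are illustrative, not actual ORT code:

```cpp
#include <xnnpack.h>
#include <pthreadpool.h>

// Assumes 'op' was already created with xnn_create_fully_connected_nc_f32 (whose signature now
// also takes code-cache and weights-cache arguments, as the Gemm/MatMul PrePack changes show).
xnn_status RunFullyConnectedF32(xnn_operator_t op, size_t batch_size,
                                const float* input, float* output,
                                pthreadpool_t threadpool) {
  // Step 1: 'reshape' receives the shape information and the thread pool. For some operators
  // reshape also reports a workspace size that the caller must allocate and hand to setup
  // (not shown here; per the description the EP copies that pattern from the XNNPACK unit tests).
  xnn_status status = xnn_reshape_fully_connected_nc_f32(op, batch_size, threadpool);
  if (status != xnn_status_success) return status;

  // Step 2: 'setup' now only binds the data pointers. Previously a single setup call took
  // batch_size, input, output and the thread pool together.
  status = xnn_setup_fully_connected_nc_f32(op, input, output);
  if (status != xnn_status_success) return status;

  return xnn_run_operator(op, threadpool);
}
```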
This commit is contained in:
Parent: e36d003765
Commit: 4f2096be38

@@ -3,10 +3,15 @@
#The columns are separated by ";" because a list in cmake is just a ";" separated group of strings.
#Names should be in lower case. They will be used as variable names in cmake.
#URLs can be either https URLs or local file paths in cmake-style(directory separator is a forward slash character).
#SHA1 hashes can be generated by running sha1sum command.
#SHA1 hashes can be generated by running sha1sum command on linux. PowerShell can also be used:
# (Get-FileHash -Algorithm SHA1 <filename>).Hash.ToLower()
#If you need to change abseil's version to a different one, you may also want to update external\abseil-cpp.natvis
#since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would
#not affect built binaries.
#
# NOTE: You must run deps_update_and_upload.py when ready to test your changes in a CI.
# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29
#
abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
@@ -18,7 +23,7 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc
googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c
googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
@@ -35,7 +40,7 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v
protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617
protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef
psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013
pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/1787867f6183f056420e532eec640cba25efafea.zip;e43e80781560c5ab404a4da20f34d846f5f5d101
pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip;07a0aa91dd9bf86f31b95497e00f31d8a261a4bd
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13
pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip;85da3caa60eb2b148613b443fbc2bfdc30689965
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374

@@ -25,17 +25,23 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR})

FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool})
onnxruntime_fetchcontent_makeavailable(pthreadpool)
FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch)

FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch
)
onnxruntime_fetchcontent_makeavailable(googlexnnpack)
set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR})
set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include)

set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool)


# the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
# See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels
message("Adding WebAssembly Source Files to XNNPACK")
set(wasm_srcs "")

file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config)

# Replace newlines with semicolon so that it is treated as a list by CMake
@@ -70,25 +76,23 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(${target_srcs} ${bazel_srcs} PARENT_SCOPE)
endfunction()

GetSrcListFromBazel("PROD_SCALAR_WASM_MICROKERNEL_SRCS" prod_scalar_wasm_srcs)
GetSrcListFromBazel("ALL_WASM_MICROKERNEL_SRCS" all_wasm_srcs)
GetSrcListFromBazel("WASM32_ASM_MICROKERNEL_SRCS" wasm32_asm_srcs)
GetSrcListFromBazel("OPERATOR_SRCS" operator_srcs)
GetSrcListFromBazel("TABLE_SRCS" table_srcs)
list(APPEND wasm_srcs ${operator_srcs} ${table_srcs})

message(DEBUG "prod_scalar_wasm_srcs: ${prod_scalar_wasm_srcs}\n")
message(DEBUG "all_wasm_srcs: ${all_wasm_srcs}\n")
message(DEBUG "wasm32_asm_srcs: ${wasm32_asm_srcs}\n")

message("Adding WebAssembly Source Files to XNNPACK")
set(wasm_srcs "")
list(APPEND wasm_srcs ${prod_scalar_wasm_srcs})
list(APPEND wasm_srcs ${all_wasm_srcs})
list(APPEND wasm_srcs ${wasm32_asm_srcs})

target_sources(XNNPACK PRIVATE ${wasm_srcs})
# kernels
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c)

if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
GetSrcListFromBazel("ALL_WASMSIMD_MICROKERNEL_SRCS" all_wasmsimd_srcs)
message(DEBUG "all_wasmsimd_srcs: ${all_wasmsimd_srcs}")
target_sources(XNNPACK PRIVATE ${all_wasmsimd_srcs})
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
target_compile_options(XNNPACK PRIVATE "-msimd128")
endif()

message(DEBUG "wasm_srcs: ${wasm_srcs}\n")
target_sources(XNNPACK PRIVATE ${wasm_srcs})

# add flags from BAZEL.build
target_compile_options(XNNPACK PRIVATE "-fno-fast-math")
target_compile_options(XNNPACK PRIVATE "-fno-math-errno")
endif()

@@ -15,7 +15,8 @@
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool
flatbuffers::flatbuffers Boost::mp11 safeint_interface
)

add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
@@ -35,4 +36,4 @@
# there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline'
if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32)
endif()
endif()

@@ -41,7 +41,7 @@ function(AddTest)
if (MSVC)
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6330>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6330>")
#Abseil has a lot of C4127/C4324 warnings.
#Abseil has a lot of C4127/C4324 warnings.
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4127>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4127>")
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4324>"
@@ -201,8 +201,18 @@ function(AddTest)
list(APPEND TEST_NODE_FLAGS "--experimental-wasm-simd")
endif()

# prefer Node from emsdk so the version is more deterministic
if (DEFINED ENV{EMSDK_NODE})
set(NODE_EXECUTABLE $ENV{EMSDK_NODE})
else()
# warning as we don't know what node version is being used and whether things like the TEST_NODE_FLAGS
# will be valid. e.g. "--experimental-wasm-simd" is not valid with node v20 or later.
message(WARNING "EMSDK_NODE environment variable was not set. Falling back to system `node`.")
set(NODE_EXECUTABLE node)
endif()

add_test(NAME ${_UT_TARGET}
COMMAND node ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS}
COMMAND ${NODE_EXECUTABLE} ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS}
WORKING_DIRECTORY $<TARGET_FILE_DIR:${_UT_TARGET}>
)
endif()

@@ -192,8 +192,13 @@ else()
onnxruntime_util
re2::re2
)

set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'")

if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'")
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ALLOW_TABLE_GROWTH=1")
endif()

if(onnxruntime_USE_WEBNN)
@@ -204,7 +209,6 @@ else()
target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard)
endif()

set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']")
if (onnxruntime_USE_JSEP)
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
else()
@@ -212,7 +216,7 @@ else()
endif()

target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}"
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
"SHELL:-s MAXIMUM_MEMORY=4294967296"
"SHELL:-s EXIT_RUNTIME=0"

@@ -1,66 +1,27 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d53c48aa1..77c3cf983 100755
index dba9b4687..bcaa18ad7 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -105,22 +105,12 @@ ENDIF()

@@ -122,7 +122,7 @@ ENDIF()
# ---[ Build flags
IF(NOT CMAKE_SYSTEM_NAME)
MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined")
-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS)$")
+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS|Emscripten|iOS)$")
MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}")
-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT)$")
+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT|Emscripten|iOS)$")
MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"")
ENDIF()

# ---[ Download deps
IF(NOT XNNPACK_USE_SYSTEM_LIBS)
- IF(NOT DEFINED CLOG_SOURCE_DIR)
- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)")
- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt")
- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download")
- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build .
- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download")
- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory")
- ENDIF()
-
IF(NOT DEFINED CPUINFO_SOURCE_DIR)
MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)")
CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt")
@@ -7108,6 +7098,10 @@ IF(MSVC)
SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O2 >")
SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O2 >")
SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: /O1 >")
+ELSEIF(CMAKE_GENERATOR STREQUAL Xcode)
+ TARGET_COMPILE_OPTIONS(all_microkernels PRIVATE $<$<NOT:$<CONFIG:Debug>>: -O2 >)
+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$<NOT:$<CONFIG:Debug>>: -O2 >)
+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$<NOT:$<CONFIG:Debug>>: -Os >)
ELSE()
SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: -O2 >")
SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$<NOT:$<CONFIG:Debug>>: -O2 >")
@@ -7142,26 +7136,6 @@ IF(LIBM)
TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM})
IF(CMAKE_SYSTEM_NAME MATCHES "Windows")
@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY)
TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging)
TARGET_LINK_LIBRARIES(post-operation PRIVATE logging)
TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run)
- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
+ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph)
+ ELSE()
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
+ ENDIF()
SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES)
ENDIF()

-# ---[ Configure clog
-IF(NOT TARGET clog)
- IF(NOT XNNPACK_USE_SYSTEM_LIBS)
- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "")
- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "")
- ADD_SUBDIRECTORY(
- "${CLOG_SOURCE_DIR}/deps/clog"
- "${CMAKE_BINARY_DIR}/clog")
- # We build static version of clog but a dynamic library may indirectly depend on it
- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON)
- ELSE()
- ADD_LIBRARY(clog STATIC IMPORTED)
- FIND_LIBRARY(CLOG_LIBRARY clog)
- IF(NOT CLOG_LIBRARY)
- MESSAGE(FATAL_ERROR "Cannot find clog")
- ENDIF()
- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}")
- ENDIF()
-ENDIF()
-
# ---[ Configure cpuinfo
IF(NOT TARGET cpuinfo)
IF(NOT XNNPACK_USE_SYSTEM_LIBS)
IF(NOT MSVC)

@@ -20,15 +20,15 @@ Do not modify directly.*
| Asinh | ai.onnx(9+) | |
| Atan | ai.onnx(7+) | |
| Atanh | ai.onnx(9+) | |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
| BiasAdd | com.microsoft(1+) | |
| BiasSplitGelu | com.microsoft(1+) | |
| Cast | ai.onnx(6-8,9-12,13-18,19+) | |
| Ceil | ai.onnx(6-12,13+) | |
| Clip | ai.onnx(6-10,11,12,13+) | |
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; conv3d is not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
| Div | ai.onnx(7-12,13,14+) | |
@@ -57,7 +57,7 @@ Do not modify directly.*
| LessOrEqual | ai.onnx(12-15,16+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
| MemcpyFromHost | ai.onnx(1+) | |
| MemcpyToHost | ai.onnx(1+) | |
| Mul | ai.onnx(7-12,13,14+) | |
@@ -79,7 +79,7 @@ Do not modify directly.*
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
| Relu | ai.onnx(6-12,13,14+) | |
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
| Sin | ai.onnx(7+) | |

@@ -164,7 +164,10 @@ async function initializeSession(
session = await ort.InferenceSession.create(modelFilePath, sessionConfig);
}
} catch (e) {
Logger.error('TestRunner', `Failed to load model from file: ${modelFilePath}. Error: ${inspect(e)}`);
Logger.error(
'TestRunner',
`Failed to load model from file: ${modelFilePath}. ` +
`Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`);
throw e;
}

@@ -62,8 +62,13 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,

auto create_error_message = [&node, &status](const std::string& prefix) {
std::ostringstream errormsg;
errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")";
errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). ";
errormsg << prefix;
const auto& domain = node.Domain();
if (!domain.empty()) {
errormsg << domain << ".";
}
errormsg << node.OpType() << "(" << node.SinceVersion() << ")"
<< " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). ";
if (!status.IsOK())
errormsg << status.ErrorMessage();

@@ -90,12 +90,16 @@ void RegisterNHWCSchemaWithActivation(const RegistrationFunc& f, ::ONNX_NAMESPAC
void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMESPACE::OpSchema&&)>& fn) {
// if the operator may be fused with an activation, use the WITH_ACTIVATION variant to add optional attributes
// for the activation parameters.
// For now we only register operators from opset 11 on. Models can easily have their opset updated using ONNX tools
// We mainly register operators from opset 11 on. Models can easily have their opset updated using ONNX tools
// so supporting older opsets is unnecessary.
// Older opsets are included on a per-operator basis as needed.

// NOTE: This should be in sync with GetLayoutSensitiveOps in
// /onnxruntime/core/optimizer/transpose_optimization/transpose_optimizer.cc
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 7);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 10);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 19);

REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 9);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14);
@@ -106,16 +110,18 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES

REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, InstanceNormalization, 6);

REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, Conv, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, Conv, 11);

REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, ConvTranspose, 11);

REGISTER_NHWC_SCHEMA(fn, GlobalAveragePool, 1);
REGISTER_NHWC_SCHEMA(fn, GlobalLpPool, 2);
REGISTER_NHWC_SCHEMA(fn, GlobalMaxPool, 1);

REGISTER_NHWC_SCHEMA(fn, GridSample, 16);
REGISTER_NHWC_SCHEMA(fn, GridSample, 20);

REGISTER_NHWC_SCHEMA(fn, LRN, 1);
REGISTER_NHWC_SCHEMA(fn, LRN, 13);
@@ -123,6 +129,9 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES
REGISTER_NHWC_SCHEMA(fn, LpPool, 11);
REGISTER_NHWC_SCHEMA(fn, LpPool, 18);

REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 1);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 8);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 10);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 11);
REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, MaxPool, 12);

@@ -3908,7 +3908,8 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std
// kernel lookup works as per usual, if not using an existing schema.
if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) {
ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node),
"Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType());
"Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType(),
" SinceVersion:", fused_node.SinceVersion());
} else if (IndexedSubGraph::SourceOfSchema::REUSE_OR_CREATE == sub_graph.schema_source) {
auto schema_key = GenerateSchemaKey(sub_graph);
if (reusable_fused_schema_map_.count(schema_key) == 0) {

@@ -232,6 +232,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
}
}

// special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported
if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) {
auto onnx_version = domain_to_version.find(kOnnxDomain);
if (onnx_version != domain_to_version.end()) {
domain_to_version[kMSInternalNHWCDomain] = onnx_version->second;
}
}

auto domain_map = allow_official_onnx_release_only_final
? schema_registry->GetLastReleasedOpsetVersions(false)
: schema_registry->GetLatestOpsetVersions(false);

@@ -66,17 +66,6 @@ bool ConvertNodeLayout(const api::NodeRef& node) {
const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();

// handle special cases
#if defined(USE_XNNPACK)
if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) {
if (node.OpType() == "Resize") {
// XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize
// with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through
// the Resize during transpose optimization.
return false;
}
}
#endif

#if defined(USE_JSEP)
// TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed
if (node.GetExecutionProviderType() == kJsExecutionProvider) {

@@ -17,7 +17,7 @@ static bool EPAwareHandleResize(HandlerArgs& args) {
// layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled
// by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs.
const auto ep_type = args.node.GetExecutionProviderType();
if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) {
if (ep_type == kCpuExecutionProvider) {
// allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model
int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
if (rank_int == 4) {

@@ -352,7 +352,7 @@ class UpsampleBase {
(scales.size() == 4 && scales[0] == 1 && scales[3] == 1) ||
scales.size() == 3 ||
(scales.size() == 5 && scales[0] == 1 && scales[1] == 1),
"'Linear' mode only support:\n"
"'Linear' mode only supports:\n"
" * 2-D inputs or\n"
" * 3-D inputs ('Bilinear', 'Trilinear') or\n"
" * 4-D inputs with the corresponding outermost 2 scale values being 1"

@@ -238,18 +238,40 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Whe
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
@@ -257,17 +279,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax);
@@ -291,11 +302,17 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
@@ -304,11 +321,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
@@ -322,8 +334,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, InstanceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range);

@@ -508,18 +521,40 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
@@ -527,17 +562,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
@@ -575,7 +599,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize)>,
@@ -594,8 +618,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range)>,

@@ -25,6 +25,13 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<false>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Conv,
kMSInternalNHWCDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<true>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Conv,
kOnnxDomain,

@@ -24,6 +24,13 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<false>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
ConvTranspose,
kMSInternalNHWCDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<true>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
ConvTranspose,
kOnnxDomain,

@@ -52,15 +52,20 @@ namespace js {
Pool<pool_type, is_channels_last>);

POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11)
POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1)
POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1)

POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED(MaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11)
POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12)

@@ -51,6 +51,7 @@ namespace js {
REGISTER_RESIZE_KERNEL(domain, 19);

REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kOnnxDomain);
REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kMSInternalNHWCDomain);
REGISTER_RESIZE_KERNEL_DOMAIN(kOnnxDomain);
REGISTER_RESIZE_KERNEL_DOMAIN(kMSInternalNHWCDomain);

@@ -7,22 +7,22 @@

#include "core/common/common.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/common.h"
#include "core/providers/cpu/nn/pool_attributes.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/xnnpack/detail/utils.h"

// each operator provides a helper to check if supported
#include "core/providers/xnnpack/math/gemm.h"
#include "core/providers/xnnpack/math/matmul.h"
#include "core/providers/xnnpack/math/softmax.h"
#include "core/providers/xnnpack/nn/average_pool.h"
#include "core/providers/xnnpack/nn/conv.h"
#include "core/providers/xnnpack/nn/conv_transpose.h"
#include "core/providers/xnnpack/nn/max_pool.h"
#include "core/providers/xnnpack/math/gemm.h"
#include "core/providers/xnnpack/math/matmul.h"
#include "core/providers/xnnpack/nn/average_pool.h"
#include "core/providers/xnnpack/nn/resize.h"
#include "core/providers/xnnpack/nn/softmax.h"
#include "core/providers/xnnpack/tensor/resize.h"

namespace onnxruntime {
namespace xnnpack {

@@ -25,7 +25,7 @@ const char* OpTypeToString(OpComputeType opCtype) {
case op_compute_type_fp16:
return "fp16";
case op_compute_type_qs8_per_channel:
return "qc8";
return "qs8_qc8w";
case op_compute_type_qs8:
return "qs8";
case op_compute_type_qu8:

@@ -78,7 +78,7 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra
return supported;
}

Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info) {
Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*enable_caches*/ true) {
const auto& node{Node()};

info.GetAttrOrDefault<float>("alpha", &alpha_, 1.f);
@@ -146,14 +146,9 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride,
B_->Data<float>(), // const float* kernel,
bias_Data, // const float* bias,
output_min,
output_max,
output_min, output_max,
flags,
#ifdef XNN_CACHE_ENABLE
&xnn_caches_,
#else
0,
#endif
GetCodeCache(), GetWeightsCache(),
&p);

if (status != xnn_status_success) {
@@ -165,20 +160,25 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr,
}

Status Gemm::Compute(OpKernelContext* context) const {
pthreadpool_t t_pool = GetThreadPool();
pthreadpool_t threadpool = GetThreadPool();
const auto* A = context->Input<Tensor>(0);
auto Y = context->Output(0, {M_, N_});

// if input is empty tensor, return as nothing need to be calculated and we've set the shape for the output
if (M_ == 0 || N_ == 0)
if (M_ == 0 || N_ == 0) {
return Status::OK();
}

xnn_status status = xnn_setup_fully_connected_nc_f32(
op0_.get(),
trans_A_ == CblasNoTrans ? M_ : K_, // Number of rows to multiply
A->Data<float>(),
Y->MutableData<float>(),
t_pool);
xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(),
// Number of rows to multiply
trans_A_ == CblasNoTrans ? M_ : K_,
threadpool);

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
}

status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data<float>(), Y->MutableData<float>());

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
@@ -192,7 +192,15 @@ Status Gemm::Compute(OpKernelContext* context) const {
return Status::OK();
}

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 8, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 9, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm);

@@ -41,14 +41,6 @@ class Gemm : protected GemmBase, public XnnpackKernel {

float alpha_;
float beta_;

#ifdef XNN_CACHE_ENABLE
#if XNN_PLATFORM_JIT
xnn_code_cache code_cache_;
#endif
xnn_caches xnn_caches_ = {0, 0};
xnn_weights_cache weights_cache_;
#endif
};

} // namespace xnnpack

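The per-kernel cache members removed in the hunk above (and from matmul.h below) are replaced by the cache wrapper on the base XNNPACK EP kernel described in the PR description. A minimal sketch of the shape such a wrapper could take, assuming the `xnn_code_cache_t`/`xnn_weights_cache_t` handle types from `xnnpack.h`; the class layout, names other than GetCodeCache/GetWeightsCache, and the creation call are illustrative assumptions, not the actual onnxruntime implementation:

```cpp
#include <xnnpack.h>

// Sketch only: derived kernels opt in via a constructor flag and pass the returned handles
// straight to the xnn_create_*_nc_* calls, which accept nullptr for "no cache".
class XnnpackKernelCachesSketch {
 public:
  explicit XnnpackKernelCachesSketch(bool enable_caches) {
    if (enable_caches) {
      // A weights cache can be created through XNNPACK's public API (assumed call, not verified
      // against this exact XNNPACK commit), e.g. xnn_create_weights_cache(&weights_cache_);
      // The code cache is not reachable through the public API, so GetCodeCache() stays nullptr,
      // matching the limitation called out in the PR description.
    }
  }

  xnn_code_cache_t GetCodeCache() const { return nullptr; }
  xnn_weights_cache_t GetWeightsCache() const { return weights_cache_; }

 private:
  xnn_weights_cache_t weights_cache_ = nullptr;
};
```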
@ -62,7 +62,7 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g
|
|||
return supported;
|
||||
}
|
||||
|
||||
MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info) {}
|
||||
MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) {}
|
||||
|
||||
Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
/*out*/ bool& is_packed,
|
||||
|
@ -99,9 +99,11 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
|||
output_max,
|
||||
flags,
|
||||
#ifdef XNN_CACHE_ENABLE
|
||||
&xnn_caches_,
|
||||
GetCodeCache(),
|
||||
GetWeightsCache(),
|
||||
#else
|
||||
0,
|
||||
nullptr,
|
||||
nullptr,
|
||||
#endif
|
||||
&p);
|
||||
|
||||
|
@@ -116,7 +118,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,

Status MatMul::Compute(OpKernelContext* ctx) const {
  const Tensor* a = ctx->Input<Tensor>(0);
  pthreadpool_t t_pool = GetThreadPool();
  pthreadpool_t threadpool = GetThreadPool();
  MatMulComputeHelper helper;
  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
  Tensor* y = ctx->Output(0, helper.OutputShape());
@@ -126,13 +128,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const {

  auto* y_data = y->MutableData<float>();

  xnn_status status = xnn_setup_fully_connected_nc_f32(
      op0_.get(),
      a->Shape()[0],
      a->Data<float>(),
      y_data,
      t_pool);
  xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
  }

  status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
  }
|
||||
|
@ -144,7 +145,11 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kXnnpackExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 8, kXnnpackExecutionProvider,
                                  KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
                                  MatMul);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 9, 12, kXnnpackExecutionProvider,
                                  KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
                                  MatMul);
|
||||
|
||||
|
|
|
@ -32,14 +32,6 @@ class MatMul : public XnnpackKernel {
|
|||
AllocatorPtr myAlloc;
|
||||
|
||||
XnnpackOperator op0_ = nullptr;
|
||||
|
||||
#ifdef XNN_CACHE_ENABLE
|
||||
#if XNN_PLATFORM_JIT
|
||||
xnn_code_cache code_cache_;
|
||||
#endif
|
||||
xnn_caches xnn_caches_ = {0, 0};
|
||||
xnn_weights_cache weights_cache_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace xnnpack
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/xnnpack/nn/softmax.h"
|
||||
#include "core/providers/xnnpack/math/softmax.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
|
@ -25,6 +25,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph
|
|||
output_type != TensorTypeUint8) {
|
||||
break;
|
||||
}
|
||||
|
||||
    // to ensure its output scale and zp are 1/256 and 0, otherwise the xnnpack EP has to do extra requantization
    // ideally, QlinearSoftmax or QDQSoftmax will keep this output scale and zp, but we have to handle some
    // qdq models converted from other frameworks
|
||||
|
@ -33,6 +34,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph
|
|||
if (fabs(q_scale.DataAsSpan<float>()[0] - 1.0f / 256.0f) > 0.0001f) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (zero_tensor) {
|
||||
Initializer q_zp(*zero_tensor, node_unit.ModelPath());
|
||||
if (q_zp.DataAsSpan<uint8_t>()[0] != 0) {
|
||||
|
@ -57,6 +59,7 @@ bool Softmax::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
IsQuantSoftmaxSupported(node_unit, graph) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// use do {} while(false) so it's easier to set a breakpoint on the return
|
||||
do {
|
||||
// SoftMax has 1 input.
|
||||
|
@ -133,6 +136,7 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} {
|
|||
ORT_ENFORCE(status.IsOK(), "opset must be existed in attributes of QlinearSoftmax");
|
||||
opset_ = gsl::narrow_cast<int>(opset);
|
||||
}
|
||||
|
||||
int64_t axis = -1;
|
||||
Status status = info.GetAttr<int64_t>("axis", &axis);
|
||||
// our op checker function has ensured that axis must be the last dim
|
||||
|
@ -162,23 +166,22 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} {
|
|||
if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
// the order of input tensor, x,x_scale, x_zp, y_scale, y_zp
|
||||
OpQuantParam quant_param = ParseQuantParamForOp(info, x_dtype, 1);
|
||||
xstatus = xnn_create_softmax_nc_qu8(
|
||||
channels,
|
||||
channels,
|
||||
channels,
|
||||
quant_param[0].first[0], // x_scale
|
||||
quant_param[1].second, // y_zp
|
||||
quant_param[1].first[0], // y_scale
|
||||
0, // flags,
|
||||
&p);
|
||||
xstatus = xnn_create_softmax_nc_qu8(channels,
|
||||
channels,
|
||||
channels,
|
||||
quant_param[0].first[0], // x_scale
|
||||
quant_param[1].second, // y_zp
|
||||
quant_param[1].first[0], // y_scale
|
||||
0, // flags,
|
||||
&p);
|
||||
} else if (op_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
xstatus = xnn_create_softmax_nc_f32(
|
||||
channels,
|
||||
channels,
|
||||
channels,
|
||||
0, // flags,
|
||||
&p);
|
||||
xstatus = xnn_create_softmax_nc_f32(channels,
|
||||
channels,
|
||||
channels,
|
||||
0, // flags,
|
||||
&p);
|
||||
}
|
||||
|
||||
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_",
|
||||
OpTypeToString(op_type_), " failed. Status:", xstatus);
|
||||
op0_.reset(p);
|
||||
|
@ -194,39 +197,48 @@ Status Softmax::Compute(OpKernelContext* ctx) const {
|
|||
if (X_shape.Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
pthreadpool_t t_pool = GetThreadPool();
|
||||
|
||||
pthreadpool_t threadpool = GetThreadPool();
|
||||
const size_t N = X_shape.SizeToDimension(axis_);
|
||||
// const size_t D = X_shape.SizeFromDimension(axis_); // the step D is 1
|
||||
xnn_status status = xnn_status_invalid_state;
|
||||
if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_softmax_nc_qu8(
|
||||
op0_.get(),
|
||||
N,
|
||||
X->Data<uint8_t>(),
|
||||
Y->MutableData<uint8_t>(),
|
||||
t_pool);
|
||||
} else {
|
||||
status = xnn_setup_softmax_nc_f32(
|
||||
op0_.get(),
|
||||
N,
|
||||
X->Data<float>(),
|
||||
Y->MutableData<float>(),
|
||||
t_pool);
|
||||
}
|
||||
|
||||
auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? xnn_reshape_softmax_nc_qu8
|
||||
: xnn_reshape_softmax_nc_f32;
|
||||
status = reshape_fn(op0_.get(), N, threadpool);
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_",
|
||||
OpTypeToString(op_type_), " returned ", status);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_softmax_nc_", OpTypeToString(op_type_),
|
||||
" returned ", status);
|
||||
}
|
||||
status = xnn_run_operator(op0_.get(), t_pool);
|
||||
|
||||
if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_softmax_nc_qu8(op0_.get(), X->Data<uint8_t>(), Y->MutableData<uint8_t>());
|
||||
} else {
|
||||
status = xnn_setup_softmax_nc_f32(op0_.get(), X->Data<float>(), Y->MutableData<float>());
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_", OpTypeToString(op_type_),
|
||||
" returned ", status);
|
||||
}
|
||||
|
||||
status = xnn_run_operator(op0_.get(), threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 12, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 10, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Softmax);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Softmax);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(Softmax, kOnnxDomain, 13, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Softmax);
|
|
@ -2,10 +2,13 @@
|
|||
// Licensed under the MIT License.
|
||||
#include "core/providers/xnnpack/nn/average_pool.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/providers/utils.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/xnnpack/xnnpack_init.h"
|
||||
#include "core/providers/xnnpack/detail/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
@ -90,6 +93,10 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
const auto& inputs = node_unit.Inputs();
|
||||
// use do {} while(false) so it's easier to set a breakpoint on the return
|
||||
do {
|
||||
if (node_unit.SinceVersion() < 7) {
|
||||
break;
|
||||
}
|
||||
|
||||
// AveragePool has 1 input.
|
||||
const auto& x_arg = inputs[0].node_arg;
|
||||
|
||||
|
@ -141,6 +148,11 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
break;
|
||||
}
|
||||
|
||||
// need dilations to all be 1
|
||||
if (!pool_attrs.default_dilations) {
|
||||
break;
|
||||
}
|
||||
|
||||
supported = true;
|
||||
} while (false);
|
||||
|
||||
|
@@ -221,24 +233,47 @@ Status AveragePool::Compute(OpKernelContext* context) const {
    return Status::OK();
  }

  pthreadpool_t t_pool = GetThreadPool();
  xnn_status status = xnn_status_invalid_state;
  pthreadpool_t threadpool = GetThreadPool();

  // set up the allocator and automated deallocation for the workspace
  size_t workspace_size = 0;
  size_t workspace_alignment = 0;
  xnn_allocator* allocator = GetStoredAllocator().second;
  auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };

  std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);

  auto reshape_fn = (avgpool_type_ == OpComputeType::op_compute_type_fp32)
                        ? xnn_reshape_average_pooling2d_nhwc_f32
                        : xnn_reshape_average_pooling2d_nhwc_qu8;

  auto status = reshape_fn(op0_.get(), N, H, W,
                           &workspace_size, &workspace_alignment,
                           /*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
                           threadpool);

  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_),
                           " returned ", status);
  }

  workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

  if (avgpool_type_ == OpComputeType::op_compute_type_fp32) {
    status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), N, H, W,
                                                  X.Data<float>(), Y.MutableData<float>(),
                                                  t_pool /*threadpool */);
    status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(),
                                                  X.Data<float>(), Y.MutableData<float>());

  } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) {
    status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), N, H, W,
                                                  X.Data<uint8_t>(), Y.MutableData<uint8_t>(),
                                                  t_pool /*threadpool */);
    status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(),
                                                  X.Data<uint8_t>(), Y.MutableData<uint8_t>());
  }

  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_",
                           OpTypeToString(avgpool_type_), " returned ", status);
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_),
                           " returned ", status);
  }

  status = xnn_run_operator(op0_.get(), t_pool);
  status = xnn_run_operator(op0_.get(), threadpool);
  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
  }
@@ -246,8 +281,26 @@ Status AveragePool::Compute(OpKernelContext* context) const {
  return Status::OK();
}
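Several of the updated kernels (AveragePool here, and Conv and Resize below) now also manage a scratch workspace: the reshape call reports the required size and alignment, the kernel allocates the buffer through the `xnn_allocator` registered by the EP, and the setup call receives the workspace pointer. A minimal sketch of that pattern using the same helpers as the code above (`GetStoredAllocator`, `XNN_ALLOCATION_ALIGNMENT`); it restates the flow rather than adding new behaviour:

```cpp
// Sketch: workspace handling around the reshape/setup split.
xnn_allocator* allocator = GetStoredAllocator().second;  // allocator the EP registered with XNNPACK
auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);

size_t workspace_size = 0;
size_t workspace_alignment = 0;

// reshape reports how much scratch memory this input shape needs
xnn_status status = xnn_reshape_average_pooling2d_nhwc_f32(
    op0_.get(), N, H, W,
    &workspace_size, &workspace_alignment,
    /*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
    threadpool);

// allocate it with XNNPACK's required alignment and hand it to setup along with the data pointers
workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));
status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(),
                                              X.Data<float>(), Y.MutableData<float>());
```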
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
AveragePool, kMSInternalNHWCDomain, 7, 9,
|
||||
kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
AveragePool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
AveragePool, kMSInternalNHWCDomain, 10, 10,
|
||||
kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
AveragePool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
AveragePool, kMSInternalNHWCDomain, 11, 18,
|
||||
kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
AveragePool);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
AveragePool, kMSInternalNHWCDomain, 11,
|
||||
AveragePool, kMSInternalNHWCDomain, 19,
|
||||
kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
AveragePool);
|
||||
|
|
|
@ -3,12 +3,13 @@
|
|||
|
||||
#include "conv.h"
|
||||
|
||||
#include "core/common/gsl.h"
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/transpose_helper.h"
|
||||
#include "core/providers/utils.h"
|
||||
#include "core/providers/xnnpack/xnnpack_init.h"
|
||||
#include "core/providers/xnnpack/detail/utils.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/common/gsl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace xnnpack {
|
||||
|
@@ -64,21 +65,48 @@ Status Conv::Compute(OpKernelContext* context) const {
  if (Y->Shape().Size() == 0) {
    return Status::OK();
  }
  pthreadpool_t t_pool = GetThreadPool();

  xnn_status status = xnn_status_invalid_state;
  if (conv_type_ == OpComputeType::op_compute_type_fp32) {
    status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), N, H, W, X.Data<float>(), Y->MutableData<float>(),
                                              t_pool /*threadpool*/);
  } else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
    status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), N, H, W, X.Data<int8_t>(), Y->MutableData<int8_t>(),
                                              t_pool /*threadpool*/);
  pthreadpool_t threadpool = GetThreadPool();

  // set up the allocator and automated deallocation for the workspace
  size_t workspace_size = 0;
  size_t workspace_alignment = 0;
  xnn_allocator* allocator = GetStoredAllocator().second;
  auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
  std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);

  auto reshape_fn = xnn_reshape_convolution2d_nhwc_f32;
  if (conv_type_ == OpComputeType::op_compute_type_qs8) {
    reshape_fn = xnn_reshape_convolution2d_nhwc_qs8;
  } else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
    status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), N, H, W, X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
                                              t_pool /*threadpool*/);
    reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
  } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
    status = xnn_setup_convolution2d_nhwc_qc8(op0_.get(), N, H, W, X.Data<int8_t>(), Y->MutableData<int8_t>(),
                                              t_pool /*threadpool*/);
    reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
  }

  auto status = reshape_fn(op0_.get(), N, H, W,
                           &workspace_size, &workspace_alignment,
                           /*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
                           threadpool);
  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
                           "returned ", status);
  }

  workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

  if (conv_type_ == OpComputeType::op_compute_type_fp32) {
    status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data<float>(),
                                              Y->MutableData<float>());
  } else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
    status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data<int8_t>(),
                                              Y->MutableData<int8_t>());
  } else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
    status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), workspace.get(), X.Data<uint8_t>(),
                                              Y->MutableData<uint8_t>());
  } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
    status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data<int8_t>(),
                                                   Y->MutableData<int8_t>());
  }

  if (status != xnn_status_success) {
@@ -86,7 +114,7 @@ Status Conv::Compute(OpKernelContext* context) const {
                           OpTypeToString(conv_type_), "returned ", status);
  }

  status = xnn_run_operator(op0_.get(), t_pool);
  status = xnn_run_operator(op0_.get(), threadpool);
  if (status != xnn_status_success) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
  }
@@ -94,6 +122,10 @@ Status Conv::Compute(OpKernelContext* context) const {
  return Status::OK();
}
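Because the type-specific reshape entry points share a signature, Conv::Compute above selects the right one once via a function pointer and keeps a single call site and error check; the qs8 per-channel variant also picks up the renamed qs8_qc8w suffix. A trimmed sketch of that dispatch, using only names that appear in the change:

```cpp
// Sketch: choose the reshape entry point that matches the kernel's compute type.
auto reshape_fn = xnn_reshape_convolution2d_nhwc_f32;
if (conv_type_ == OpComputeType::op_compute_type_qs8) {
  reshape_fn = xnn_reshape_convolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
  reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
  reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;  // renamed from the old _qc8 suffix
}

// one call site and one error message cover every compute type
auto status = reshape_fn(op0_.get(), N, H, W,
                         &workspace_size, &workspace_alignment,
                         /*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
                         threadpool);
if (status != xnn_status_success) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
                         " returned ", status);
}
```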
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Conv);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Conv);
|
||||
|
|
|
@ -23,7 +23,8 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
const std::optional<std::pair<float, float>>& clip_min_max,
|
||||
const Tensor& Weight, const Tensor* Bias,
|
||||
XnnpackOperator& op_uptr,
|
||||
xnn_caches_t caches_t,
|
||||
xnn_code_cache_t code_cache,
|
||||
xnn_weights_cache_t weights_cache,
|
||||
const OpQuantParam& quant_param,
|
||||
OpComputeType conv_type,
|
||||
bool is_transpose = false) {
|
||||
|
@ -75,7 +76,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
C, M, // input channel stride, output channel stride
|
||||
Weight.Data<float>(), B_data,
|
||||
foutput_min, foutput_max, flags,
|
||||
caches_t,
|
||||
code_cache, weights_cache,
|
||||
&p);
|
||||
} else if (conv_type == OpComputeType::op_compute_type_qs8) {
|
||||
const float output_scale = quant_param[2].first[0];
|
||||
|
@ -99,7 +100,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
quant_param[2].second, quant_param[2].first[0],
|
||||
output_min, output_max,
|
||||
flags,
|
||||
caches_t,
|
||||
code_cache, weights_cache,
|
||||
&p);
|
||||
} else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) {
|
||||
auto* B_data = Bias ? Bias->Data<int32_t>() : nullptr;
|
||||
|
@ -107,7 +108,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
const int8_t output_zero_point = quant_param[2].second;
|
||||
const int8_t output_min = xnn_u8s8_quantize<int8_t>(foutput_min, output_scale, output_zero_point);
|
||||
const int8_t output_max = xnn_u8s8_quantize<int8_t>(foutput_max, output_scale, output_zero_point);
|
||||
status = xnn_create_convolution2d_nhwc_qc8(
|
||||
status = xnn_create_convolution2d_nhwc_qs8_qc8w(
|
||||
input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
|
||||
kernel_height, kernel_width,
|
||||
subsampling_height, subsampling_width,
|
||||
|
@ -123,7 +124,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
quant_param[2].second, quant_param[2].first[0],
|
||||
output_min, output_max,
|
||||
flags,
|
||||
caches_t,
|
||||
code_cache, weights_cache,
|
||||
&p);
|
||||
} else if (conv_type == OpComputeType::op_compute_type_qu8) {
|
||||
const auto* B_data = Bias ? Bias->Data<int32_t>() : nullptr;
|
||||
|
@ -148,15 +149,17 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr,
|
|||
quant_param[2].second, quant_param[2].first[0],
|
||||
output_min, output_max,
|
||||
flags,
|
||||
caches_t,
|
||||
code_cache, weights_cache,
|
||||
&p);
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Failed to create xnnpack kernel. xnn_create_",
|
||||
is_transpose ? "deconvolution2d" : "convolution2d", "_nhwc_",
|
||||
OpTypeToString(conv_type), " returned ", status);
|
||||
}
|
||||
|
||||
op_uptr.reset(p);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -296,6 +299,11 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&
|
|||
const onnxruntime::Node& node = node_unit.GetNode();
|
||||
// use do {} while(false) so it's easier to set a breakpoint on the return
|
||||
do {
|
||||
// Internal NHWC domain starts at opset 11
|
||||
if (node_unit.SinceVersion() < 11) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Conv has at least 2 inputs.
|
||||
const auto& inputs = node_unit.Inputs();
|
||||
const auto& x_arg = inputs[0].node_arg;
|
||||
|
@ -367,7 +375,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&
|
|||
}
|
||||
|
||||
ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
|
||||
: XnnpackKernel(info),
|
||||
: XnnpackKernel(info, /*enable_caches*/ true),
|
||||
conv_attrs_(info),
|
||||
conv_transpose_attrs_(info),
|
||||
convbase_attrs_ref_(is_transpose ? conv_transpose_attrs_ : conv_attrs_),
|
||||
|
@ -383,16 +391,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
|
|||
}
|
||||
}
|
||||
}
|
||||
// xnnpack cache_code, unfortunately these definitions are only available in xnnpack/cache.h,
|
||||
#ifdef XNN_CACHE_ENABLE
|
||||
#if XNN_PLATFORM_JIT
|
||||
xnn_init_code_cache(&code_cache_);
|
||||
xnn_caches_.code_cache = &code_cache_;
|
||||
#endif
|
||||
// TODO(Jicwen) enable weight-cache and code-cache
|
||||
xnn_init_weights_cache(&weights_cache_);
|
||||
xnn_caches_.weights_cache = &weights_cache_;
|
||||
#endif
|
||||
|
||||
const auto& node{Node()};
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const NodeArg& X = *input_defs[0];
|
||||
|
@ -477,11 +476,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
|
|||
Status ConvBase::CreateKernel() {
|
||||
auto ret = CreateXnnpackKernel(&convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_,
|
||||
B_, op0_,
|
||||
#ifdef XNN_CACHE_ENABLE
|
||||
&xnn_caches_,
|
||||
#else
|
||||
0,
|
||||
#endif
|
||||
GetCodeCache(), GetWeightsCache(),
|
||||
quant_param_, conv_type_, is_transpose_);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -39,14 +39,6 @@ class ConvBase : public XnnpackKernel {
|
|||
std::optional<std::pair<float, float>> clip_min_max_;
|
||||
|
||||
XnnpackOperator op0_ = nullptr;
|
||||
// we can't have the definition here because we can't import xnnpack/cache.h
|
||||
#ifdef XNN_CACHE_ENABLE
|
||||
#if XNN_PLATFORM_JIT
|
||||
xnn_code_cache code_cache_;
|
||||
#endif
|
||||
xnn_caches xnn_caches_ = {0, 0};
|
||||
xnn_weights_cache weights_cache_;
|
||||
#endif
|
||||
OpQuantParam quant_param_;
|
||||
OpComputeType conv_type_ = OpComputeType::op_compute_type_invalid;
|
||||
};
|
||||
|
|
|
@ -81,29 +81,34 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
|
|||
if (Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
pthreadpool_t t_pool = GetThreadPool();
|
||||
pthreadpool_t threadpool = GetThreadPool();
|
||||
|
||||
auto output_pad_0 = gsl::narrow_cast<uint32_t>(conv_transpose_attrs_.output_padding[0]);
|
||||
auto output_pad_1 = gsl::narrow_cast<uint32_t>(conv_transpose_attrs_.output_padding[1]);
|
||||
xnn_status status = xnn_status_invalid_state;
|
||||
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_f32(
|
||||
op0_.get(), N, H, W,
|
||||
output_pad_0,
|
||||
output_pad_1, X.Data<float>(), Y->MutableData<float>(),
|
||||
t_pool /*threadpool*/);
|
||||
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_qs8(
|
||||
op0_.get(), N, H, W,
|
||||
output_pad_0,
|
||||
output_pad_1, X.Data<int8_t>(), Y->MutableData<int8_t>(),
|
||||
t_pool /*threadpool*/);
|
||||
|
||||
auto reshape_fn = xnn_reshape_deconvolution2d_nhwc_f32;
|
||||
if (conv_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8;
|
||||
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_qu8(
|
||||
op0_.get(), N, H, W,
|
||||
output_pad_0,
|
||||
output_pad_1, X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
|
||||
t_pool /*threadpool*/);
|
||||
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8;
|
||||
}
|
||||
|
||||
status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1,
|
||||
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
|
||||
threadpool);
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_deconvolution2d_nhwc_",
|
||||
OpTypeToString(conv_type_), " returned ", status);
|
||||
}
|
||||
|
||||
if (conv_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
|
||||
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
|
||||
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
|
@ -111,7 +116,7 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
|
|||
OpTypeToString(conv_type_), " returned ", status);
|
||||
}
|
||||
|
||||
status = xnn_run_operator(op0_.get(), t_pool);
|
||||
status = xnn_run_operator(op0_.get(), threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
const onnxruntime::Node& node = node_unit.GetNode();
|
||||
// use do {} while(false) so it's easier to set a breakpoint on the return
|
||||
do {
|
||||
if (node_unit.SinceVersion() < 8) {
|
||||
break;
|
||||
}
|
||||
|
||||
// MaxPool has 1 input.
|
||||
auto input_defs = node.InputDefs();
|
||||
const auto& x_arg = *input_defs[0];
|
||||
|
@ -220,20 +224,29 @@ Status MaxPool::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
pthreadpool_t t_pool = GetThreadPool();
|
||||
xnn_status status = xnn_status_invalid_state;
|
||||
pthreadpool_t threadpool = GetThreadPool();
|
||||
|
||||
auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32;
|
||||
if (maxpool_type_ == OpComputeType::op_compute_type_qu8)
|
||||
reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8;
|
||||
else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8;
|
||||
}
|
||||
|
||||
auto status = reshape_fn(op0_.get(), N, H, W,
|
||||
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
|
||||
threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_max_pooling2d_nhwc_",
|
||||
OpTypeToString(maxpool_type_), " returned ", status);
|
||||
}
|
||||
|
||||
if (maxpool_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), N, H, W,
|
||||
X.Data<float>(), Y->MutableData<float>(),
|
||||
t_pool /*threadpool */);
|
||||
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), N, H, W,
|
||||
X.Data<uint8_t>(), Y->MutableData<uint8_t>(),
|
||||
t_pool /*threadpool */);
|
||||
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
|
||||
} else {
|
||||
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), N, H, W,
|
||||
X.Data<int8_t>(), Y->MutableData<int8_t>(),
|
||||
t_pool /*threadpool */);
|
||||
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
|
@ -241,7 +254,7 @@ Status MaxPool::Compute(OpKernelContext* context) const {
|
|||
OpTypeToString(maxpool_type_), " returned ", status);
|
||||
}
|
||||
|
||||
status = xnn_run_operator(op0_.get(), t_pool);
|
||||
status = xnn_run_operator(op0_.get(), threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
|
||||
}
|
||||
|
@ -249,12 +262,24 @@ Status MaxPool::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
MaxPool);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/xnnpack/nn/resize.h"
|
||||
#include "core/providers/xnnpack/tensor/resize.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
@ -10,6 +10,7 @@
|
|||
#include "core/common/inlined_containers_fwd.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/xnnpack/xnnpack_init.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace xnnpack {
|
||||
|
@ -18,26 +19,67 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
const GraphViewer& graph_viewer) {
|
||||
bool supported = false;
|
||||
do {
|
||||
if (node_unit.SinceVersion() < 10) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Resize has 1-4 input.
|
||||
const auto& inputs = node_unit.Inputs();
|
||||
const auto& x_arg = inputs[0].node_arg;
|
||||
|
||||
const auto* x_type = x_arg.TypeAsProto();
|
||||
if (x_type == nullptr ||
|
||||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
|
||||
if (x_type == nullptr || (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
|
||||
break;
|
||||
}
|
||||
|
||||
const auto* x_shape = x_arg.Shape();
|
||||
//'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) or
|
||||
// 4-D input with outermost and innermost scales as 1 (NHWC)
|
||||
// but we just support 4-d tensor for now, and the channel must be known.
|
||||
|
||||
// 'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) can be supported.
|
||||
// we only support 4-d tensor for now, and the channel must be known.
|
||||
// we assume the input in NCHW for this test.
|
||||
if (!x_shape || x_shape->dim_size() != 4 || x_shape->dim(1).dim_value() <= 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// validate it is in fact NCHW
|
||||
//
|
||||
// opset 10 had `scales` as input 1 and no sizes. later opsets added roi as input 1 followed by scales and sizes.
|
||||
auto opset_version = node_unit.SinceVersion();
|
||||
size_t scale_idx = opset_version == 10 ? 1 : 2;
|
||||
size_t size_idx = 3;
|
||||
|
||||
// onnx shape inferencing validates that one and not both of sizes and scales are provided
|
||||
const auto* scale_tensor = inputs.size() >= scale_idx + 1
|
||||
? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true)
|
||||
: nullptr;
|
||||
const auto* size_tensor = opset_version > 10 && inputs.size() >= size_idx + 1
|
||||
? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true)
|
||||
: nullptr;
|
||||
|
||||
// if both scales and sizes are nullptr the one that was provided was not a constant initializer
|
||||
if (!scale_tensor && !size_tensor) {
|
||||
break;
|
||||
}
|
||||
|
||||
// check the scale for the second dim is 1 or the size of the second dim matches the input shape.
|
||||
// if not, it is not the C dim as a Resize will not change the number of channels.
|
||||
InlinedVector<float> scale(4, 1.0F);
|
||||
if (scale_tensor) {
|
||||
const Initializer scale_val(*scale_tensor, node_unit.ModelPath());
|
||||
if (scale_val.DataAsSpan<float>()[1] != 1.0F) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (size_tensor) {
|
||||
const Initializer size_val(*size_tensor, node_unit.ModelPath());
|
||||
if (size_val.DataAsSpan<int64_t>()[1] != x_shape->dim(1).dim_value()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const auto* output_shape = node_unit.Outputs()[0].node_arg.Shape();
|
||||
bool length_resized_compatible_pytorch_half_pixel = true;
|
||||
// when length_resized > 1, there is no difference between pytorch_half_pixel and half_pixel
|
||||
|
@ -48,18 +90,11 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
// if coordinate_transformation_mode is "pytorch_half_pixel",
|
||||
// x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0
|
||||
//
|
||||
if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(1).dim_value() <= 1) {
|
||||
if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(3).dim_value() <= 1) {
|
||||
// we don't know the output H or W so we don't know if it will be compatible
|
||||
length_resized_compatible_pytorch_half_pixel = false;
|
||||
}
|
||||
|
||||
// Refer to onnxruntime/core/providers/cpu/tensor/upsamplebase.h,
|
||||
size_t scale_idx = 2;
|
||||
size_t size_idx = 3;
|
||||
auto opset_version = node_unit.SinceVersion();
|
||||
if (opset_version == 10) {
|
||||
scale_idx = 1;
|
||||
}
|
||||
|
||||
ProtoHelperNodeContext nc(node_unit.GetNode());
|
||||
OpNodeProtoHelper info(&nc);
|
||||
|
||||
|
@ -78,6 +113,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
|
||||
std::vector<int64_t> axes;
|
||||
if (info.GetAttrs<int64_t>("axes", axes).IsOK() && axes.size() > 0) {
|
||||
// TODO: We should be able to handle this if required
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -95,9 +131,10 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
// Coordinate transformation mode attr was introduced in version 11.
|
||||
// before that asymmetric mode was the only available transformation mode
|
||||
std::string coordinate_transform_mode_name =
|
||||
opset_version > 10
|
||||
? info.GetAttrOrDefault<std::string>("coordinate_transformation_mode", "half_pixel")
|
||||
: "asymmetric";
|
||||
opset_version > 10 ? info.GetAttrOrDefault<std::string>("coordinate_transformation_mode", "half_pixel")
|
||||
: "asymmetric";
|
||||
|
||||
// TODO: Opset 19 added half_pixel_symmetric. Need to see if that can be supported.
|
||||
|
||||
if (coordinate_transform_mode_name != "asymmetric" &&
|
||||
coordinate_transform_mode_name != "half_pixel" &&
|
||||
|
@ -106,59 +143,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
break;
|
||||
}
|
||||
|
||||
auto exclude_outside = info.GetAttrOrDefault<int64_t>("exclude_outside", 0) == 0 ? false : true;
|
||||
if (exclude_outside) {
|
||||
break;
|
||||
}
|
||||
|
||||
// roi only takes effect when coordinate_transformation_mode is "tf_crop_and_resize"
|
||||
|
||||
// size or scales shouldnt't be provided in the same time but should at least be provided one of them
|
||||
const auto* scale_tensor = inputs.size() >= scale_idx + 1
|
||||
? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true)
|
||||
: nullptr;
|
||||
const auto* size_tensor = inputs.size() >= size_idx + 1
|
||||
? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true)
|
||||
: nullptr;
|
||||
|
||||
bool has_size = false;
|
||||
bool has_scale = false;
|
||||
InlinedVector<float> scale(4, 1.0F);
|
||||
if (scale_tensor) {
|
||||
const Initializer scale_val(*scale_tensor, node_unit.ModelPath());
|
||||
auto scale_span = scale_val.DataAsSpan<float>();
|
||||
if (scale_span.size() == 4) {
|
||||
has_scale = true;
|
||||
std::copy(scale_span.begin(), scale_span.end(), scale.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (size_tensor) {
|
||||
auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*x_shape);
|
||||
const Initializer size_val(*size_tensor, node_unit.ModelPath());
|
||||
|
||||
auto size_span = size_val.DataAsSpan<int64_t>();
|
||||
if (size_span.size() == 4) {
|
||||
has_size = true;
|
||||
scale = {size_span[0] / static_cast<float>(input_shape[0]),
|
||||
size_span[1] / static_cast<float>(input_shape[1]),
|
||||
size_span[2] / static_cast<float>(input_shape[2]),
|
||||
size_span[3] / static_cast<float>(input_shape[3])};
|
||||
}
|
||||
}
|
||||
|
||||
if ((has_size && has_scale) || (!has_size && !has_scale)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (scale[0] != 1.0F || (scale[1] != 1.0F && scale[3] != 1.0F)) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only support xnn_create_resize_bilinear2d_nchw_f32
|
||||
const bool is_NHWC = scale[3] == 1.0F;
|
||||
if (!is_NHWC && (x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
|
||||
x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
|
||||
if (info.GetAttrOrDefault<int64_t>("exclude_outside", 0) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -210,8 +195,7 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf
|
|||
}
|
||||
}
|
||||
|
||||
is_NHWC_ = scales_[3] == 1.0F;
|
||||
int64_t channels = x_shape->dim(is_NHWC_ ? 3 : 1).dim_value();
|
||||
int64_t channels = x_shape->dim(3).dim_value();
|
||||
|
||||
uint32_t flags = 0;
|
||||
ORT_ENFORCE(mode_ == UpsampleMode::LINEAR, "only support bilinear resize");
|
||||
|
@ -225,18 +209,16 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf
|
|||
xnn_status xstatus = xnn_status_invalid_state;
|
||||
struct xnn_operator* p = nullptr;
|
||||
if (op_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
auto create_func = is_NHWC_ ? xnn_create_resize_bilinear2d_nhwc_f32 : xnn_create_resize_bilinear2d_nchw_f32;
|
||||
xstatus = create_func(
|
||||
channels, channels, channels, flags, &p);
|
||||
xstatus = xnn_create_resize_bilinear2d_nhwc_f32(channels, channels, channels, flags, &p);
|
||||
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
xstatus = xnn_create_resize_bilinear2d_nhwc_u8(
|
||||
channels, channels, channels, flags, &p);
|
||||
xstatus = xnn_create_resize_bilinear2d_nhwc_u8(channels, channels, channels, flags, &p);
|
||||
} else {
|
||||
xstatus = xnn_create_resize_bilinear2d_nhwc_s8(
|
||||
channels, channels, channels, flags, &p);
|
||||
xstatus = xnn_create_resize_bilinear2d_nhwc_s8(channels, channels, channels, flags, &p);
|
||||
}
|
||||
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_",
|
||||
OpTypeToString(op_type_), " failed. Status:", xstatus);
|
||||
|
||||
ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:",
|
||||
xstatus);
|
||||
|
||||
op0_.reset(p);
|
||||
}
|
||||
|
||||
|
@ -245,48 +227,56 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input,
|
|||
const TensorShapeVector& output_dims) const {
|
||||
const auto& X_shape = input->Shape();
|
||||
auto N = X_shape[0];
|
||||
auto H = is_NHWC_ ? X_shape[1] : X_shape[2];
|
||||
auto W = is_NHWC_ ? X_shape[2] : X_shape[3];
|
||||
auto H = X_shape[1];
|
||||
auto W = X_shape[2];
|
||||
Tensor* output = ctx->Output(0, TensorShape(output_dims));
|
||||
|
||||
pthreadpool_t t_pool = GetThreadPool();
|
||||
xnn_status status = xnn_status_invalid_state;
|
||||
if (op_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
auto oH = is_NHWC_ ? output_dims[1] : output_dims[2];
|
||||
auto oW = is_NHWC_ ? output_dims[2] : output_dims[3];
|
||||
auto setup_func = is_NHWC_ ? xnn_setup_resize_bilinear2d_nhwc_f32 : xnn_setup_resize_bilinear2d_nchw_f32;
|
||||
status = setup_func(
|
||||
op0_.get(),
|
||||
N,
|
||||
H, W, oH, oW,
|
||||
input->Data<float>(),
|
||||
output->MutableData<float>(),
|
||||
t_pool);
|
||||
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_resize_bilinear2d_nhwc_u8(
|
||||
op0_.get(),
|
||||
N,
|
||||
H, W, output_dims[1], output_dims[2],
|
||||
input->Data<uint8_t>(),
|
||||
output->MutableData<uint8_t>(),
|
||||
t_pool);
|
||||
} else {
|
||||
status = xnn_setup_resize_bilinear2d_nhwc_s8(
|
||||
op0_.get(),
|
||||
N,
|
||||
H, W, output_dims[1], output_dims[2],
|
||||
input->Data<int8_t>(),
|
||||
output->MutableData<int8_t>(),
|
||||
t_pool);
|
||||
pthreadpool_t threadpool = GetThreadPool();
|
||||
|
||||
// setup allocator/automated dellocate for workspace
|
||||
size_t workspace_size = 0;
|
||||
size_t workspace_alignment = 0;
|
||||
xnn_allocator* allocator = GetStoredAllocator().second;
|
||||
auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); };
|
||||
std::unique_ptr<void, decltype(deallocator)> workspace(nullptr, deallocator);
|
||||
|
||||
auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32;
|
||||
if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8;
|
||||
} else if (op_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8;
|
||||
}
|
||||
|
||||
auto status = reshape_fn(op0_.get(), N, H, W, output_dims[1], output_dims[2],
|
||||
&workspace_size, &workspace_alignment, threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_),
|
||||
" returned ", status);
|
||||
}
|
||||
|
||||
workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));
|
||||
|
||||
if (op_type_ == OpComputeType::op_compute_type_fp32) {
|
||||
status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data<float>(),
|
||||
output->MutableData<float>());
|
||||
} else if (op_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data<uint8_t>(),
|
||||
output->MutableData<uint8_t>());
|
||||
} else {
|
||||
status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data<int8_t>(),
|
||||
output->MutableData<int8_t>());
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_resize_bilinear2d_nhwc_",
|
||||
OpTypeToString(op_type_), " returned ", status);
|
||||
}
|
||||
status = xnn_run_operator(op0_.get(), t_pool);
|
||||
|
||||
status = xnn_run_operator(op0_.get(), threadpool);
|
||||
if (status != xnn_status_success) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -315,29 +305,29 @@ Status Resize::Compute(OpKernelContext* ctx) const {
|
|||
return ComputeInternal(ctx, X, output_shape);
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 10, 10, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
Resize);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 11, 12, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
Resize);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 13, 17, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
Resize);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 18, 18, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
Resize);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider,
|
||||
ONNX_OPERATOR_KERNEL_EX(Resize, kMSInternalNHWCDomain, 19, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
|
@@ -31,7 +31,6 @@ class Resize : public UpsampleBase, public XnnpackKernel {
                        const TensorShapeVector& output_dims) const;

 private:
  bool is_NHWC_;
  XnnpackOperator op0_;
  TensorShapeVector output_dims_;
  OpComputeType op_type_ = OpComputeType::op_compute_type_invalid;
|
|
@@ -27,88 +27,117 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
  return info;
}

#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \
  BuildKernelCreateInfo<                             \
      ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, End, Op)>
#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op, Domain) \
  BuildKernelCreateInfo<                                     \
      ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)>

#define KERNEL_CREATE_INFO(Start, Op) \
  BuildKernelCreateInfo<              \
      ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, Op)>
#define KERNEL_CREATE_INFO(Start, Op, Domain) \
  BuildKernelCreateInfo<                      \
      ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)>

#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \
  BuildKernelCreateInfo<                          \
      ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, type, Op)>
#define KERNEL_CREATE_INFO_TYPED(Start, Type, Op, Domain) \
  BuildKernelCreateInfo<                                  \
      ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>
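The extra Domain parameter lets one macro family register kernels in either the ONNX domain or the internal NHWC domain. For reference, an entry such as KERNEL_CREATE_INFO_VERSIONED(7, 9, AveragePool, kMSInternalNHWCDomain) in the table below is shorthand for the corresponding BuildKernelCreateInfo instantiation (an illustrative expansion, not an additional registration):

```cpp
// Illustrative expansion of the Domain-aware macros defined above.
KERNEL_CREATE_INFO_VERSIONED(7, 9, AveragePool, kMSInternalNHWCDomain)
// expands to:
BuildKernelCreateInfo<
    ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool)>

KERNEL_CREATE_INFO(13, Gemm, kOnnxDomain)
// expands to:
BuildKernelCreateInfo<
    ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm)>
```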
|
||||
|
||||
// Layout sensitive operators in NHWC domain
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearAveragePool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider,
|
||||
kDynamicDomainByCreate, 1, QLinearSoftmax);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose);
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearAveragePool);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, Resize);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);

// ONNX operators
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 8, MatMul);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 10, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax);

// Internal domain
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax);

std::unique_ptr<KernelRegistry> RegisterKernels() {
  auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();

  static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<void>,  // default entry to avoid the list becoming empty after ops-reducing

      KERNEL_CREATE_INFO(11, Conv),
      KERNEL_CREATE_INFO(11, ConvTranspose),
      KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose),
      KERNEL_CREATE_INFO(1, QLinearConvTranspose),
      KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool),
      KERNEL_CREATE_INFO(12, MaxPool),
      KERNEL_CREATE_INFO(11, AveragePool),
      // layout sensitive. nodes will be moved to kMSInternalNHWCDomain by layout transformation
      KERNEL_CREATE_INFO_VERSIONED(7, 9, AveragePool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(10, 10, AveragePool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(11, 18, AveragePool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO(19, AveragePool, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO_VERSIONED(1, 10, Conv, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO(11, Conv, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO(11, ConvTranspose, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO_VERSIONED(8, 9, MaxPool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(10, 10, MaxPool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO(12, MaxPool, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO(1, QLinearConvTranspose, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO_VERSIONED(10, 10, Resize, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(11, 12, Resize, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(13, 17, Resize, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_VERSIONED(18, 18, Resize, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO(19, Resize, kMSInternalNHWCDomain),

      // layout insensitive, use ONNX-domain directly
      BuildKernelCreateInfo<
          ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
      BuildKernelCreateInfo<
          ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul)>,
      KERNEL_CREATE_INFO_VERSIONED(1, 10, Softmax, kOnnxDomain),
      KERNEL_CREATE_INFO_VERSIONED(11, 12, Softmax, kOnnxDomain),
      KERNEL_CREATE_INFO(13, Softmax, kOnnxDomain),

      KERNEL_CREATE_INFO_VERSIONED(7, 8, Gemm, kOnnxDomain),
      KERNEL_CREATE_INFO_VERSIONED(9, 10, Gemm, kOnnxDomain),
      KERNEL_CREATE_INFO_VERSIONED(11, 12, Gemm, kOnnxDomain),
      KERNEL_CREATE_INFO(13, Gemm, kOnnxDomain),

      KERNEL_CREATE_INFO_VERSIONED(1, 8, MatMul, kOnnxDomain),
      KERNEL_CREATE_INFO_VERSIONED(9, 12, MatMul, kOnnxDomain),
      KERNEL_CREATE_INFO(13, MatMul, kOnnxDomain),

      // quantization op
      KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv),
      KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv),
      KERNEL_CREATE_INFO(1, QLinearAveragePool),
      BuildKernelCreateInfo<
          ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax)>,
      KERNEL_CREATE_INFO(1, QLinearAveragePool, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv, kMSInternalNHWCDomain),
      KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain),

      KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),
  };

  for (auto& function_table_entry : function_table) {
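The table above registers each operator for an explicit opset range, and the kMSInternalNHWCDomain entries mirror the ONNX-domain ones so that nodes moved by layout transformation still resolve to a kernel. As a minimal standalone sketch of how a node's since-version selects one of those ranges (illustrative only; real resolution goes through onnxruntime::KernelRegistry, and the names below are made up):

// Illustrative sketch only: a simplified opset-range lookup, not ORT's KernelRegistry.
// The ranges mirror the Gemm registrations above: [7,8], [9,10], [11,12], [13, latest].
#include <iostream>
#include <string>
#include <vector>

struct KernelRange {
  int start_version;
  int end_version;  // use a large value for "latest"
  std::string kernel_name;
};

const std::string* FindKernel(const std::vector<KernelRange>& table, int since_version) {
  for (const auto& entry : table) {
    if (since_version >= entry.start_version && since_version <= entry.end_version) {
      return &entry.kernel_name;
    }
  }
  return nullptr;
}

int main() {
  const std::vector<KernelRange> gemm_table = {
      {7, 8, "Gemm_7_8"}, {9, 10, "Gemm_9_10"}, {11, 12, "Gemm_11_12"}, {13, 1 << 30, "Gemm_13"}};
  std::cout << *FindKernel(gemm_table, 9) << "\n";   // prints Gemm_9_10
  std::cout << *FindKernel(gemm_table, 13) << "\n";  // prints Gemm_13
  return 0;
}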
@@ -26,13 +26,15 @@ void xnn_deallocate(void* context, void* pointer) {
}

void* xnn_aligned_allocate(void* context, size_t alignment, size_t size) {
  if (size == 0)
    return nullptr;

#if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__)
  ORT_ENFORCE(alignment <= 2 * sizeof(void*));
  return xnn_allocate(context, size);
#else
  void* ptr = xnn_allocate(context, size);
  ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0,
              " xnnpack wants to allocate a space with ", alignment, "bytes aligned. But it's not satisfied");
  ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0, "xnnpack allocation was not aligned to ", alignment, " bytes.");
  // if ptr is not aligned, we have to find a way to return an aligned ptr and store the original ptr
  return ptr;
#endif
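The comment above hints at the usual fallback when an allocator cannot guarantee alignment: over-allocate, hand out the next aligned address, and stash the original pointer just below it. This is only a sketch of that general technique (it assumes a power-of-two alignment and is not what the EP currently does, since it simply enforces that the allocation is already aligned):

// Sketch of the over-allocate-and-stash approach referred to in the comment above.
// Not the EP's current behavior; shown only to illustrate how an unaligned allocator
// could still hand out aligned pointers. alignment must be a power of two.
#include <cassert>
#include <cstdint>
#include <cstdlib>

void* aligned_alloc_fallback(std::size_t alignment, std::size_t size) {
  // room for the alignment shift plus the stored original pointer
  void* original = std::malloc(size + alignment + sizeof(void*));
  if (original == nullptr) return nullptr;
  std::uintptr_t raw = reinterpret_cast<std::uintptr_t>(original) + sizeof(void*);
  std::uintptr_t aligned = (raw + alignment - 1) & ~(std::uintptr_t{alignment} - 1);
  reinterpret_cast<void**>(aligned)[-1] = original;  // stash the original just below the aligned block
  return reinterpret_cast<void*>(aligned);
}

void aligned_free_fallback(void* ptr) {
  if (ptr != nullptr) std::free(reinterpret_cast<void**>(ptr)[-1]);
}

int main() {
  void* p = aligned_alloc_fallback(64, 100);
  assert((reinterpret_cast<std::uintptr_t>(p) & 63) == 0);
  aligned_free_fallback(p);
  return 0;
}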
@@ -5,6 +5,47 @@ struct xnn_allocator;
namespace onnxruntime {
namespace xnnpack {

// copy #define logic from XNNPACK src/xnnpack/common.h to determine workspace alignment
#if defined(__APPLE__)
#include <TargetConditionals.h>
#endif

#if defined(__i386__) || defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(_M_IX86)
#define XNN_ARCH_X86 1
#else
#define XNN_ARCH_X86 0
#endif

#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) && !defined(_M_ARM64EC)
#define XNN_ARCH_X86_64 1
#else
#define XNN_ARCH_X86_64 0
#endif

#if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__)
#define XNN_ARCH_WASM 1
#else
#define XNN_ARCH_WASM 0
#endif

#if defined(__ANDROID__) || (defined(__APPLE__) && TARGET_OS_IPHONE)
#define XNN_PLATFORM_MOBILE 1
#else
#define XNN_PLATFORM_MOBILE 0
#endif

#if XNN_ARCH_WASM
#define XNN_ALLOCATION_ALIGNMENT 4
#elif XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_PLATFORM_MOBILE
#define XNN_ALLOCATION_ALIGNMENT 32
#else
#define XNN_ALLOCATION_ALIGNMENT 64
#endif
#else
#define XNN_ALLOCATION_ALIGNMENT 16
#endif

std::pair<AllocatorPtr&, xnn_allocator*> GetStoredAllocator();

} // namespace xnnpack
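The alignment selected above is what a workspace buffer is expected to satisfy. As a standalone illustration only (64 stands in for XNN_ALLOCATION_ALIGNMENT on desktop x86-64, and the EP actually allocates through its registered ORT allocator rather than the C library; std::aligned_alloc is also unavailable with MSVC), an aligned allocation typically needs the size rounded up to a multiple of the alignment:

// Standalone sketch, not the EP's workspace code: std::aligned_alloc requires the
// requested size to be a multiple of the alignment, so round it up first.
#include <cassert>
#include <cstdint>
#include <cstdlib>

constexpr std::size_t kWorkspaceAlignment = 64;  // stand-in for XNN_ALLOCATION_ALIGNMENT

int main() {
  std::size_t requested = 1000;
  std::size_t rounded = (requested + kWorkspaceAlignment - 1) / kWorkspaceAlignment * kWorkspaceAlignment;
  void* workspace = std::aligned_alloc(kWorkspaceAlignment, rounded);
  assert(workspace != nullptr);
  assert(reinterpret_cast<std::uintptr_t>(workspace) % kWorkspaceAlignment == 0);
  std::free(workspace);
  return 0;
}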
@@ -4,6 +4,7 @@
#pragma once
#include "core/framework/op_kernel.h"
#include "core/providers/xnnpack/xnnpack_execution_provider.h"
#include "xnnpack.h"

struct pthreadpool;

@@ -12,18 +13,59 @@ namespace xnnpack {

class XnnpackKernel : public OpKernel {
 public:
  explicit XnnpackKernel(const OpKernelInfo& info)
      : OpKernel(info),
        xnnpack_threadpool_(
            static_cast<const XnnpackExecutionProvider*>(info.GetExecutionProvider())
                ->GetPrivateThreadPool()) {
  explicit XnnpackKernel(const OpKernelInfo& info, bool enable_caches = false)
      : OpKernel{info},
        xnnpack_threadpool_{
            static_cast<const XnnpackExecutionProvider*>(info.GetExecutionProvider())->GetPrivateThreadPool()},
        caches_{enable_caches} {
  }
  [[nodiscard]] pthreadpool* GetThreadPool() const {
    return xnnpack_threadpool_;
  }

  // see comment below about enabling code cache
  // xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();}
  xnn_code_cache_t GetCodeCache() { return nullptr; }
  xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); }

 private:
  pthreadpool* xnnpack_threadpool_;

  // Helper class to wrap usage of the XNNPACK weights and code caches.
  // NOTE: Currently creating/freeing the code cache is not exposed via the public xnnpack.h header so usage is
  // commented out. If we need to use it, we'll need to add the 'src' directory of XNNPACK to the include path
  // and #include "xnnpack/cache.h"
  struct Caches {
    Caches(bool enable)
        :  // auto_code_cache(nullptr, xnn_release_code_cache),
          auto_weights_cache(nullptr, xnn_delete_weights_cache) {
      if (enable) {
#ifdef XNN_CACHE_ENABLE
        xnn_status status = xnn_status_success;
#if XNN_PLATFORM_JIT
        // status = xnn_init_code_cache(&code_cache_);
        // ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");
        // auto_code_cache.reset(&code_cache_);
#endif
        // status = xnn_init_weights_cache(&weights_cache_);
        xnn_weights_cache_t weights_cache = nullptr;
        status = xnn_create_weights_cache(&weights_cache, 0);
        ORT_ENFORCE(status == xnn_status_success, "Failed to create XNNPACK weights cache");
        auto_weights_cache.reset(weights_cache);
#endif
      }
    }

    // std::unique_ptr<xnn_code_cache, decltype(&xnn_release_code_cache)> auto_code_cache;
    std::unique_ptr<xnn_weights_cache, decltype(&xnn_delete_weights_cache)> auto_weights_cache;

    // private:
    // #if defined(XNN_CACHE_ENABLE) && XNN_PLATFORM_JIT
    //  xnn_code_cache code_cache_;
    // #endif
  };

  Caches caches_;
};
} // namespace xnnpack
} // namespace onnxruntime
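The Caches helper above relies on std::unique_ptr with a C-style release function as its deleter so the weights cache is freed automatically when the kernel is destroyed. A self-contained illustration of that pattern with stand-in types (FakeCache and the fake_* functions are not XNNPACK APIs; the real deleter is xnn_delete_weights_cache from xnnpack.h):

// Standalone illustration of the unique_ptr-with-function-deleter pattern used by
// the Caches struct above. FakeCache / fake_create / fake_delete are stand-ins.
#include <cstdio>
#include <memory>

struct FakeCache { int id; };
FakeCache* fake_create(int id) { std::printf("create %d\n", id); return new FakeCache{id}; }
void fake_delete(FakeCache* c) { std::printf("delete %d\n", c->id); delete c; }

int main() {
  // starts empty, like auto_weights_cache(nullptr, xnn_delete_weights_cache)
  std::unique_ptr<FakeCache, decltype(&fake_delete)> cache{nullptr, fake_delete};
  cache.reset(fake_create(1));  // adopt a cache created by the C-style API
  return 0;
}  // fake_delete runs automatically here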
@@ -399,6 +399,8 @@ bool SetEpsForAllNodes(Graph& graph,
                       const std::vector<std::unique_ptr<IExecutionProvider>>& execution_providers,
                       const std::vector<std::shared_ptr<CustomRegistry>>* custom_registries) {
  const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
  const KernelRegistry::TypeConstraintMap type_constraint_map{};

  for (auto& node : graph.Nodes()) {
    if (node.OpType() == kConstant)
      continue;
@@ -426,13 +428,28 @@ bool SetEpsForAllNodes(Graph& graph,
        break;
      }

      // check the internal NHWC domain if EP requests NHWC as it may only have a kernel registered in that domain
      if (ep->GetPreferredLayout() == DataLayout::NHWC) {
        const KernelCreateInfo* kci = nullptr;
        auto status = ep->GetKernelRegistry()->TryFindKernel(ep->Type(),
                                                             std::string_view(node.OpType()),
                                                             std::string_view(kMSInternalNHWCDomain),
                                                             node.SinceVersion(),
                                                             type_constraint_map,
                                                             &kci);
        if (status.IsOK() && kci != nullptr) {
          found = true;
          break;
        }
      }

      // Check the EP has an impl for the node from custom_registries
      if (custom_registries != nullptr &&
          std::any_of(custom_registries->cbegin(), custom_registries->cend(),
                      [&](auto reg) { return KernelRegistry::HasImplementationOf(
                                          *reg->GetKernelRegistry(),
                                          node, ep->Type(),
                                          kernel_type_str_resolver); })) {
                      [&](auto reg) {
                        return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(),
                                                                   kernel_type_str_resolver);
                      })) {
        found = true;
        break;
      }
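The added block falls back to the internal NHWC domain when the EP prefers NHWC and the ONNX-domain lookup found nothing. A simplified standalone sketch of that two-step lookup (a plain map instead of ORT's KernelRegistry; the domain string values used below are illustrative assumptions, not taken from this diff):

// Simplified sketch of the ONNX-domain-then-NHWC-domain fallback lookup.
// Domain string values are illustrative assumptions.
#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  const std::string onnx_domain = "";  // ONNX ops use the empty domain string
  const std::string internal_nhwc_domain = "com.ms.internal.nhwc";

  // registry keyed by (domain, op type); only an NHWC-domain kernel is registered
  const std::map<std::pair<std::string, std::string>, std::string> registry = {
      {{internal_nhwc_domain, "Conv"}, "XNNPACK NHWC Conv kernel"},
  };

  const std::string op_type = "Conv";
  const std::string* found = nullptr;
  for (const auto& domain : {onnx_domain, internal_nhwc_domain}) {
    auto it = registry.find({domain, op_type});
    if (it != registry.end()) {
      found = &it->second;
      break;
    }
  }
  std::cout << (found ? *found : "no kernel") << "\n";  // resolved via the NHWC domain
  return 0;
}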
@@ -760,7 +777,7 @@ void BaseTester::ExecuteModelForEps(
      for (const auto& ep : execution_providers) {
        providers.append(ep->Type() + " ");
      }
      LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << "were unable to run the model.";
      LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << " were unable to run the model.";
      return;
    }

@@ -21,7 +21,7 @@ struct ConvOp {
  std::unique_ptr<CompareOpTester> get_test() {
    RandomValueGenerator random{};

    auto test = std::make_unique<CompareOpTester>("Conv", 7);
    auto test = std::make_unique<CompareOpTester>("Conv", 11);  // internal NHWC domain starts at opset 11
    std::vector<T> input_data = random.Uniform<T>(input_dims, 0.0f, 1.0f);

    std::vector<int64_t> weight_dims{channels, input_dims[1] / group, kernel_shape[0], kernel_shape[1]};
@@ -203,7 +203,8 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
}

TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
  const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model.onnx";
  // the internal NHWC domain supports opset 11 and later
  const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx";

  SessionOptions so;
  // set this if you want to manually inspect the optimized model
Binary file not shown.
@@ -11,7 +11,7 @@ steps:
    packageType: upack
    feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
    definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
    version: 1.0.104
    version: 1.0.107
    downloadPath: $(Build.BinariesDirectory)/deps

# The private ADO project
@@ -22,7 +22,7 @@ steps:
    packageType: upack
    feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
    definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
    version: 1.0.104
    version: 1.0.107
    downloadPath: $(Build.BinariesDirectory)/deps

# You can add more ADO accounts at here.