Add cuDNN Frontend and use it for CUDA NN Convolution (#19470)

### Description
Added the cuDNN Frontend and use it for NHWC convolutions, optionally fusing the activation.

#### Backward compatible
- Models that already contain FusedConv can still run.
- If ORT is built with cuDNN 8, the cuDNN frontend is not built into the binary and the old kernels (using the cuDNN backend APIs) are used.

#### Major changes
- For cuDNN 9, the cuDNN frontend fuses convolution and bias when the provider option `fuse_conv_bias=1` is set.
- Remove the FusedConv fusion from the graph transformer for the CUDA provider, so FusedConv will no longer be added to graphs for the CUDA EP.
- Update the cmake files for the cuDNN settings. The build searches for the cuDNN installation in the following order:
  * environment variable `CUDNN_PATH`
  * the `onnxruntime_CUDNN_HOME` cmake define. If a build starts from build.py/build.sh, it can be passed through the `--cudnn_home` parameter, or through the environment variable `CUDNN_HOME` if `--cudnn_home` is not used.
  * the cuDNN Python package installation directory, e.g. python3.xx/site-packages/nvidia/cudnn
  * the CUDA installation path

#### Potential issues
- If ORT is built with cuDNN 8, the FusedConv fusion is no longer done automatically, so some models might see a performance regression. Users who still want the FusedConv operator for performance reasons have several workarounds: use an older version of onnxruntime, or use an older version of ORT to save the optimized ONNX model and then run it with the latest version of ORT. We believe the majority of users will have moved to cuDNN 9 by the 1.20 release (the default in ORT and PyTorch will have been cuDNN 9 for three months by then), so the impact is small.
- A cuDNN graph uses TF32 by default, and TF32 cannot be disabled through the `use_tf32` CUDA provider option. If a user encounters an accuracy issue (for example in testing), they have to set the environment variable `NVIDIA_TF32_OVERRIDE=0` to disable TF32. The documentation of `use_tf32` needs to be updated later.

#### Follow-ups
This is one of the PRs that target enabling NHWC convolution in the CUDA EP by default when the device supports it. Other changes will follow to make that possible:
1. Enable `prefer_nhwc` by default for devices with sm >= 70.
2. Change `fuse_conv_bias=1` to be the default after more testing.
3. Add other NHWC operators (like Resize or Upsample).

### Motivation and Context
The new cuDNN Frontend library provides the functionality to fuse operations and provides new heuristics for kernel selection. Here it fuses the convolution with the pointwise bias operation. On [NVIDIA ResNet50](https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/) we get a performance boost from 49.1144 ms to 42.4643 ms per inference on a 2560x1440 input (`onnxruntime_perf_test -e cuda -I -q -r 100 -d 1 -i 'prefer_nhwc|1' resnet50.onnx`).

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: Maximilian Mueller <maximilianm@nvidia.com>
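As a usage sketch of the provider options introduced above: the option names (`prefer_nhwc`, `fuse_conv_bias`) and the `NVIDIA_TF32_OVERRIDE` workaround come from this PR, while the model path and the choice of values below are illustrative only.

```python
import os

import onnxruntime as ort

# Only needed if accuracy issues are observed: the cuDNN frontend graph
# currently ignores the use_tf32 provider option, so TF32 has to be
# disabled through the environment instead.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

providers = [
    (
        "CUDAExecutionProvider",
        {
            "prefer_nhwc": "1",     # prefer NHWC convolution kernels
            "fuse_conv_bias": "1",  # let the cuDNN frontend fuse Conv + Bias (cuDNN 9)
        },
    ),
    "CPUExecutionProvider",
]

# "resnet50.onnx" is a placeholder model path.
session = ort.InferenceSession("resnet50.onnx", providers=providers)
```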
This commit is contained in:
Parent: 0e708de4fc
Commit: 1391354265
@@ -351,6 +351,16 @@
        },
        "comments": "directx_headers"
      }
    },
    {
      "component": {
        "type": "git",
        "git": {
          "commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019",
          "repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git"
        },
        "comments": "cudnn_frontend"
      }
    }
  ]
}
@@ -729,9 +729,6 @@ set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)

if (onnxruntime_USE_CUDA)
  if (onnxruntime_USE_CUDA_NHWC_OPS)
    add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
  endif()
  enable_language(CUDA)
  message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
@@ -1445,9 +1442,6 @@ if (onnxruntime_USE_CUDA)
    file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
  endif()
  find_package(CUDAToolkit REQUIRED)
  if(onnxruntime_CUDNN_HOME)
    file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
  endif()
  if (NOT CMAKE_CUDA_ARCHITECTURES)
    if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
      # Support for Jetson/Tegra ARM devices
@@ -58,3 +58,4 @@ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
@@ -0,0 +1,111 @@
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)

find_path(
    CUDNN_INCLUDE_DIR cudnn.h
    HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS}
    PATH_SUFFIXES include
    REQUIRED
)

file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")

function(find_cudnn_library NAME)
  find_library(
      ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
      HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
      PATH_SUFFIXES lib64 lib/x64 lib
      REQUIRED
  )

  if(${NAME}_LIBRARY)
    add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
    set_target_properties(
      CUDNN::${NAME} PROPERTIES
      INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
      IMPORTED_LOCATION ${${NAME}_LIBRARY}
    )
    message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
  else()
    message(STATUS "${NAME} not found.")
  endif()

endfunction()

find_cudnn_library(cudnn)

include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
    LIBRARY REQUIRED_VARS
    CUDNN_INCLUDE_DIR cudnn_LIBRARY
)

if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)

  message(STATUS "cuDNN: ${cudnn_LIBRARY}")
  message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")

  set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()

  set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")

endif()

target_include_directories(
  CUDNN::cudnn_all
  INTERFACE
      $<INSTALL_INTERFACE:include>
      $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)

target_link_libraries(
  CUDNN::cudnn_all
  INTERFACE
      CUDNN::cudnn
)

if(CUDNN_MAJOR_VERSION EQUAL 8)
  find_cudnn_library(cudnn_adv_infer)
  find_cudnn_library(cudnn_adv_train)
  find_cudnn_library(cudnn_cnn_infer)
  find_cudnn_library(cudnn_cnn_train)
  find_cudnn_library(cudnn_ops_infer)
  find_cudnn_library(cudnn_ops_train)

  target_link_libraries(
    CUDNN::cudnn_all
    INTERFACE
        CUDNN::cudnn_adv_train
        CUDNN::cudnn_ops_train
        CUDNN::cudnn_cnn_train
        CUDNN::cudnn_adv_infer
        CUDNN::cudnn_cnn_infer
        CUDNN::cudnn_ops_infer
  )
elseif(CUDNN_MAJOR_VERSION EQUAL 9)
  find_cudnn_library(cudnn_cnn)
  find_cudnn_library(cudnn_adv)
  find_cudnn_library(cudnn_graph)
  find_cudnn_library(cudnn_ops)
  find_cudnn_library(cudnn_engines_runtime_compiled)
  find_cudnn_library(cudnn_engines_precompiled)
  find_cudnn_library(cudnn_heuristic)

  target_link_libraries(
    CUDNN::cudnn_all
    INTERFACE
        CUDNN::cudnn_adv
        CUDNN::cudnn_ops
        CUDNN::cudnn_cnn
        CUDNN::cudnn_graph
        CUDNN::cudnn_engines_runtime_compiled
        CUDNN::cudnn_engines_precompiled
        CUDNN::cudnn_heuristic
  )
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
@@ -0,0 +1,12 @@
include(FetchContent)
FetchContent_Declare(
  cudnn_frontend
  URL ${DEP_URL_cudnn_frontend}
  URL_HASH SHA1=${DEP_SHA1_cudnn_frontend}
)

set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
FetchContent_MakeAvailable(cudnn_frontend)
@@ -587,20 +587,16 @@ endif()
message("Finished fetching external dependencies")

set(onnxruntime_LINK_DIRS )

if (onnxruntime_USE_CUDA)
  #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
  find_package(CUDAToolkit REQUIRED)
  if (WIN32)
    if(onnxruntime_CUDNN_HOME)
      list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
    endif()
  else()
    if(onnxruntime_CUDNN_HOME)
      list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
    endif()

  if(onnxruntime_CUDNN_HOME)
    file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
    set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
  endif()
  include(cuDNN)
endif()

if(onnxruntime_USE_SNPE)
@@ -69,7 +69,7 @@ endif()
if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL)
  # TODO: for now, core framework depends on CUDA. It should be moved to TensorRT EP
  # TODO: provider_bridge_ort.cc should not include nccl.h
  target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
  target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
  target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
endif()
@@ -197,12 +197,16 @@
    target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
    target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
  else()
    target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
      ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
    if(onnxruntime_CUDNN_HOME)
      target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
      target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
    include(cudnn_frontend) # also defines CUDNN::*
    if (onnxruntime_USE_CUDA_NHWC_OPS)
      if(CUDNN_MAJOR_VERSION GREATER 8)
        add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
      else()
        message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." )
      endif()
    endif()
    target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart
      ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
  endif()

  if (onnxruntime_USE_TRITON_KERNEL)
@@ -159,7 +159,7 @@
  if(onnxruntime_CUDA_MINIMAL)
    set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
  else()
    set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
    set(trt_link_libs CUDNN::cudnn_all cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
  endif()
  file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
    "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"

@@ -183,9 +183,6 @@
  endif()
  target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
    PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
  if(onnxruntime_CUDNN_HOME)
    target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
  endif()

  # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
  set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA)
@ -98,11 +98,7 @@ endif()
|
|||
onnxruntime_add_include_to_target(onnxruntime_pybind11_state Python::Module Python::NumPy)
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${pybind11_INCLUDE_DIRS})
|
||||
if(onnxruntime_USE_CUDA)
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
# cudnn_home is optional for Window when cuda and cudnn are installed in the same directory.
|
||||
if(onnxruntime_CUDNN_HOME)
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include)
|
||||
endif()
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
|
||||
endif()
|
||||
if(onnxruntime_USE_CANN)
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include)
|
||||
|
|
|
@ -145,6 +145,7 @@ set(provider_excluded_files
|
|||
"rnn/rnn_impl.cu"
|
||||
"rnn/rnn_impl.h"
|
||||
"shared_inc/cuda_call.h"
|
||||
"shared_inc/cudnn_fe_call.h"
|
||||
"shared_inc/fpgeneric.h"
|
||||
"cuda_allocator.cc"
|
||||
"cuda_allocator.h"
|
||||
|
@ -171,6 +172,7 @@ set(provider_excluded_files
|
|||
"cuda_utils.cu"
|
||||
"cudnn_common.cc"
|
||||
"cudnn_common.h"
|
||||
"cudnn_fe_call.cc"
|
||||
"cupti_manager.cc"
|
||||
"cupti_manager.h"
|
||||
"fpgeneric.cu"
|
||||
|
|
|
@ -44,9 +44,7 @@ if (onnxruntime_USE_EXTENSIONS)
|
|||
endif()
|
||||
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
|
||||
if (onnxruntime_USE_CUDA)
|
||||
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_ROCM)
|
||||
target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1)
|
||||
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
|
||||
|
|
|
@ -39,10 +39,6 @@ endif()
|
|||
|
||||
target_include_directories(onnxruntime_training PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header} ${MPI_CXX_INCLUDE_DIRS})
|
||||
|
||||
if (onnxruntime_USE_CUDA)
|
||||
target_include_directories(onnxruntime_training PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_NCCL)
|
||||
target_include_directories(onnxruntime_training PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
@ -81,9 +77,6 @@ if (onnxruntime_BUILD_UNIT_TESTS)
|
|||
|
||||
target_include_directories(onnxruntime_training_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header})
|
||||
target_link_libraries(onnxruntime_training_runner PRIVATE nlohmann_json::nlohmann_json)
|
||||
if (onnxruntime_USE_CUDA)
|
||||
target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_NCCL)
|
||||
target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||
|
|
|
@ -67,7 +67,7 @@ function(AddTest)
|
|||
if(onnxruntime_USE_CUDA)
|
||||
#XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
|
||||
# otherwise it will impact when CUDA DLLs can be unloaded.
|
||||
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
|
||||
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend)
|
||||
endif()
|
||||
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
|
||||
endif()
|
||||
|
@ -75,7 +75,7 @@ function(AddTest)
|
|||
onnxruntime_add_include_to_target(${_UT_TARGET} date::date flatbuffers::flatbuffers)
|
||||
target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR})
|
||||
if (onnxruntime_USE_CUDA)
|
||||
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
|
||||
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
|
||||
if (onnxruntime_USE_NCCL)
|
||||
target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
@ -392,7 +392,7 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R
|
|||
)
|
||||
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src})
|
||||
|
||||
if (onnxruntime_USE_CUDA_NHWC_OPS)
|
||||
if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8)
|
||||
file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS
|
||||
"${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc"
|
||||
)
|
||||
|
@ -1498,7 +1498,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
|||
list(APPEND custom_op_src_patterns
|
||||
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu"
|
||||
"${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*")
|
||||
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
|
||||
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
|
||||
if (HAS_QSPECTRE)
|
||||
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
|
||||
endif()
|
||||
|
|
|
@@ -40,6 +40,7 @@ struct CudaContext : public CustomOpContext {
  bool enable_skip_layer_norm_strict_mode = false;
  bool prefer_nhwc = false;
  bool use_tf32 = true;
  bool fuse_conv_bias = true;

  void Init(const OrtKernelContext& kernel_ctx) {
    cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);

@@ -57,6 +58,7 @@ struct CudaContext : public CustomOpContext {
        kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
    prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
    use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
    fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
  }

  template <typename T>

@@ -38,5 +38,6 @@ struct OrtCUDAProviderOptionsV2 {
  int prefer_nhwc = 0;                  // make the CUDA EP NHWC preferred
  int use_ep_level_unified_stream = 0;  // flag specifying if ep level stream is used or not
  int use_tf32 = 1;                     // use TF32
  int fuse_conv_bias = 0;               // Enable CUDNN Frontend kernel fusing, results in JIT compiles
  int sdpa_kernel = 0;                  // Scaled Dot Product Attention kernel option
};

@@ -19,4 +19,5 @@ enum CudaResource : int {
  enable_skip_layer_norm_strict_mode_t,
  prefer_nhwc_t,
  use_tf32_t,
  fuse_conv_bias_t
};
@ -21,7 +21,7 @@ namespace cuda {
|
|||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Conv<T, true>);
|
||||
onnxruntime::cuda::Conv<T, true>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
|
|
@ -3,17 +3,66 @@
|
|||
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
#include "core/providers/cuda/tensor/slice.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
Status SliceOutUnwantedOutputSection(cudaStream_t stream,
|
||||
const void* input_data, gsl::span<const int64_t> input_dims,
|
||||
void* output_data,
|
||||
const gsl::span<const int64_t>& output_dims,
|
||||
const gsl::span<const int64_t>& starts,
|
||||
const gsl::span<const int64_t>& ends,
|
||||
const gsl::span<const int64_t>& axes,
|
||||
size_t element_size) {
|
||||
SliceOp::PrepareForComputeMetadata compute_metadata(input_dims);
|
||||
|
||||
ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata));
|
||||
|
||||
// As a sanity check, ensure that the slice operator's output shape matches with the expected output shape
|
||||
ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims));
|
||||
|
||||
return ::onnxruntime::cuda::SliceCuda::Impl(stream, input_data, input_dims, output_data,
|
||||
compute_metadata, element_size);
|
||||
}
|
||||
|
||||
static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle,
|
||||
const ::onnxruntime::cuda::CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s,
|
||||
cudnnConvolutionFwdAlgo_t algo,
|
||||
size_t* sz) {
|
||||
return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz);
|
||||
}
|
||||
|
||||
static size_t GetMaxWorkspaceSize(cudnnHandle_t handle,
|
||||
const ::onnxruntime::cuda::CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s,
|
||||
const cudnnConvolutionFwdAlgo_t* algo, int n_algo) {
|
||||
// TODO: get maximum available size from memory arena
|
||||
size_t free, total;
|
||||
CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
|
||||
// Assuming 10% of fragmentation
|
||||
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
cudnnStatus_t err;
|
||||
size_t sz;
|
||||
err = GetWorkspaceSize(handle, s, algo[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue;
|
||||
max_ws_size = sz;
|
||||
}
|
||||
return max_ws_size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class FusedConv : public onnxruntime::cuda::Conv<T, false> {
|
||||
class FusedConv : public onnxruntime::cuda::CudaKernel {
|
||||
using CudaT = typename ::onnxruntime::cuda::ToCudaType<T>::MappedType;
|
||||
|
||||
public:
|
||||
using Base = onnxruntime::cuda::Conv<T, false>;
|
||||
FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv<T, false>(info) {
|
||||
FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::CudaKernel(info), conv_attrs_(info) {
|
||||
auto pads_size = conv_attrs_.pads.size();
|
||||
ORT_ENFORCE(pads_size % 2 == 0);
|
||||
std::string activation;
|
||||
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &activation));
|
||||
ORT_THROW_IF_ERROR(MapMode(activation));
|
||||
|
@ -32,66 +81,331 @@ class FusedConv : public onnxruntime::cuda::Conv<T, false> {
|
|||
}
|
||||
}
|
||||
|
||||
Status UpdateState(OpKernelContext* context, bool bias_expected) const {
|
||||
// set X
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& x_shape = X->Shape();
|
||||
const auto x_dims = x_shape.AsShapeVector();
|
||||
s_.x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
|
||||
s_.element_size = X->DataType()->Size();
|
||||
// set W
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const TensorShape& w_shape = W->Shape();
|
||||
auto w_dims = w_shape.AsShapeVector();
|
||||
s_.w_data = reinterpret_cast<const CudaT*>(W->Data<T>());
|
||||
|
||||
// set B
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
|
||||
} else {
|
||||
s_.b_data = nullptr;
|
||||
}
|
||||
// set Z
|
||||
if (context->InputCount() >= 4) {
|
||||
const Tensor* Z = context->Input<Tensor>(3);
|
||||
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(),
|
||||
::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
s_.z_data = reinterpret_cast<const CudaT*>(Z->Data<T>());
|
||||
} else {
|
||||
s_.z_data = nullptr;
|
||||
}
|
||||
bool input_dims_changed = (s_.last_x_dims != x_dims);
|
||||
bool w_dims_changed = (s_.last_w_dims != w_dims);
|
||||
if (input_dims_changed || w_dims_changed) {
|
||||
if (input_dims_changed)
|
||||
s_.last_x_dims = gsl::make_span(x_dims);
|
||||
|
||||
if (w_dims_changed) {
|
||||
s_.last_w_dims = gsl::make_span(w_dims);
|
||||
s_.cached_benchmark_results.clear();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape()));
|
||||
|
||||
TensorShapeVector kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
|
||||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
ConvPadVector pads(conv_attrs_.pads);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_rank * 2, 0);
|
||||
}
|
||||
TensorShapeVector dilations(conv_attrs_.dilations);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_rank, 1);
|
||||
}
|
||||
TensorShapeVector strides(conv_attrs_.strides);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_rank, 1);
|
||||
}
|
||||
|
||||
TensorShapeVector y_dims;
|
||||
y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C'
|
||||
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t M = W->Shape()[0];
|
||||
y_dims.insert(y_dims.begin(), {N, M});
|
||||
|
||||
bool post_slicing_required = false;
|
||||
TensorShapeVector slice_starts;
|
||||
slice_starts.reserve(kernel_rank);
|
||||
|
||||
TensorShapeVector slice_ends;
|
||||
slice_ends.reserve(kernel_rank);
|
||||
|
||||
TensorShapeVector slice_axes;
|
||||
slice_axes.reserve(kernel_rank);
|
||||
|
||||
constexpr size_t spatial_dim_start = 2;
|
||||
const size_t spatial_dim_end = spatial_dim_start + kernel_rank;
|
||||
TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end);
|
||||
|
||||
TensorShapeVector y_dims_with_adjusted_pads(y_dims);
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape,
|
||||
strides, dilations, pads, y_dims,
|
||||
y_dims_with_adjusted_pads, post_slicing_required,
|
||||
slice_starts, slice_ends, slice_axes));
|
||||
|
||||
ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
|
||||
s_.y_dims = gsl::make_span(y_dims);
|
||||
s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
|
||||
s_.post_slicing_required = post_slicing_required;
|
||||
s_.slice_starts = slice_starts;
|
||||
s_.slice_ends = slice_ends;
|
||||
s_.slice_axes = slice_axes;
|
||||
|
||||
s_.Y = context->Output(0, TensorShape(s_.y_dims));
|
||||
if (post_slicing_required) {
|
||||
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
|
||||
s_.memory_for_cudnn_conv_results =
|
||||
GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size,
|
||||
context->GetComputeStream());
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
|
||||
} else {
|
||||
// No post slicing needed. Fill the output tensor's buffer directly.
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
|
||||
const CUDAExecutionProvider* cuda_ep =
|
||||
static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
|
||||
TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()};
|
||||
TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
|
||||
if (kernel_rank < 2) {
|
||||
// TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D]
|
||||
// especially for EXHAUSTIVE algo search which may result in a better algo selection.
|
||||
// ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to
|
||||
// inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape
|
||||
// [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad.
|
||||
// PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems
|
||||
// to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT.
|
||||
// See PR #7348 and #7702 for more context.
|
||||
if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
|
||||
x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1);
|
||||
y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1);
|
||||
w_dims.insert(w_dims.begin() + 2, 1);
|
||||
pads.insert(pads.begin() + kernel_rank, 0);
|
||||
pads.insert(pads.begin(), 0);
|
||||
kernel_shape.insert(kernel_shape.begin(), 1);
|
||||
strides.insert(strides.begin(), 1);
|
||||
dilations.insert(dilations.begin(), 1);
|
||||
} else {
|
||||
x_dims_cudnn.push_back(1);
|
||||
y_dims_cudnn.push_back(1);
|
||||
w_dims.push_back(1);
|
||||
pads.insert(pads.begin() + kernel_rank, 0);
|
||||
pads.insert(pads.end(), 0);
|
||||
kernel_shape.push_back(1);
|
||||
strides.push_back(1);
|
||||
dilations.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (w_dims_changed) {
|
||||
ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
}
|
||||
|
||||
// We must delay returning early until here so that the weight dims have been cached properly
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>(), UseTF32()));
|
||||
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
const auto& b_shape = B->Shape();
|
||||
ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
|
||||
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
|
||||
b_dims[1] = b_shape[0];
|
||||
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
// s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
|
||||
} else if (bias_expected) {
|
||||
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
|
||||
b_dims[1] = w_dims[0];
|
||||
auto malloc_size = b_dims[1] * sizeof(CudaT);
|
||||
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
|
||||
if (s_.b_zero) {
|
||||
CUDA_CALL_THROW(cudaFree(s_.b_zero));
|
||||
s_.b_zero = nullptr;
|
||||
}
|
||||
CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
|
||||
CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
|
||||
}
|
||||
|
||||
if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
|
||||
// set math type to tensor core before algorithm search
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
if (!UseTF32()) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
|
||||
}
|
||||
}
|
||||
|
||||
cudnnConvolutionFwdAlgoPerf_t perf;
|
||||
int algo_count = 1;
|
||||
int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
|
||||
ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ",
|
||||
cudnn_conv_algo);
|
||||
switch (cudnn_conv_algo) {
|
||||
case 0: {
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context),
|
||||
s_, kAllAlgos, num_algos)
|
||||
: ::onnxruntime::cuda::AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
|
||||
GetCudnnHandle(context),
|
||||
s_.x_tensor,
|
||||
s_.x_data,
|
||||
s_.w_desc,
|
||||
s_.w_data,
|
||||
s_.conv_desc,
|
||||
s_.y_tensor,
|
||||
s_.y_data,
|
||||
1, // requestedAlgoCount
|
||||
&algo_count, // returnedAlgoCount
|
||||
&perf,
|
||||
algo_search_workspace.get(),
|
||||
max_ws_size));
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
|
||||
GetCudnnHandle(context),
|
||||
s_.x_tensor,
|
||||
s_.w_desc,
|
||||
s_.conv_desc,
|
||||
s_.y_tensor,
|
||||
1, // requestedAlgoCount
|
||||
&algo_count, // returnedAlgoCount
|
||||
&perf));
|
||||
break;
|
||||
|
||||
default:
|
||||
perf.algo = kDefaultConvAlgo;
|
||||
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
|
||||
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
perf.mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else if (std::is_same<T, float>::value && !UseTF32()) {
|
||||
perf.mathType = CUDNN_FMA_MATH;
|
||||
} else {
|
||||
perf.mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
}
|
||||
s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
|
||||
}
|
||||
const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType));
|
||||
s_.algo = perf.algo;
|
||||
s_.workspace_bytes = perf.memory;
|
||||
} else {
|
||||
// set Y
|
||||
s_.Y = context->Output(0, s_.y_dims);
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (s_.post_slicing_required) {
|
||||
s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(
|
||||
TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream());
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
|
||||
} else {
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
std::lock_guard<OrtMutex> lock(Base::s_.mutex);
|
||||
std::lock_guard<OrtMutex> lock(s_.mutex);
|
||||
auto cudnnHandle = this->GetCudnnHandle(context);
|
||||
ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
|
||||
if (Base::s_.Y->Shape().Size() == 0) {
|
||||
ORT_RETURN_IF_ERROR(UpdateState(context, true));
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
bool has_z = nullptr != Base::s_.z_data;
|
||||
bool has_b = nullptr != Base::s_.b_data;
|
||||
typedef typename onnxruntime::cuda::ToCudaType<T>::MappedType CudaT;
|
||||
bool has_z = nullptr != s_.z_data;
|
||||
bool has_b = nullptr != s_.b_data;
|
||||
const auto alpha = onnxruntime::cuda::Consts<CudaT>::One;
|
||||
const auto beta = onnxruntime::cuda::Consts<CudaT>::Zero;
|
||||
IAllocatorUniquePtr<void> workspace = Base::GetWorkSpace(context->GetComputeStream());
|
||||
IAllocatorUniquePtr<void> workspace = GetWorkSpace(context->GetComputeStream());
|
||||
auto cudnn_status = cudnnConvolutionBiasActivationForward(cudnnHandle,
|
||||
&alpha,
|
||||
Base::s_.x_tensor,
|
||||
Base::s_.x_data,
|
||||
Base::s_.w_desc,
|
||||
Base::s_.w_data,
|
||||
Base::s_.conv_desc,
|
||||
Base::s_.algo,
|
||||
s_.x_tensor,
|
||||
s_.x_data,
|
||||
s_.w_desc,
|
||||
s_.w_data,
|
||||
s_.conv_desc,
|
||||
s_.algo,
|
||||
workspace.get(),
|
||||
Base::s_.workspace_bytes,
|
||||
s_.workspace_bytes,
|
||||
has_z ? &alpha : &beta,
|
||||
has_z ? Base::s_.z_tensor : Base::s_.y_tensor,
|
||||
has_z ? Base::s_.z_data : Base::s_.y_data,
|
||||
Base::s_.b_tensor,
|
||||
has_b ? Base::s_.b_data : Base::s_.b_zero,
|
||||
has_z ? s_.z_tensor : s_.y_tensor,
|
||||
has_z ? s_.z_data : s_.y_data,
|
||||
s_.b_tensor,
|
||||
has_b ? s_.b_data : s_.b_zero,
|
||||
activation_desc_,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data);
|
||||
s_.y_tensor,
|
||||
s_.y_data);
|
||||
if (CUDNN_STATUS_SUCCESS != cudnn_status) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnnHandle,
|
||||
&alpha,
|
||||
Base::s_.x_tensor,
|
||||
Base::s_.x_data,
|
||||
Base::s_.w_desc,
|
||||
Base::s_.w_data,
|
||||
Base::s_.conv_desc,
|
||||
Base::s_.algo,
|
||||
s_.x_tensor,
|
||||
s_.x_data,
|
||||
s_.w_desc,
|
||||
s_.w_data,
|
||||
s_.conv_desc,
|
||||
s_.algo,
|
||||
workspace.get(),
|
||||
Base::s_.workspace_bytes,
|
||||
s_.workspace_bytes,
|
||||
&beta,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data));
|
||||
s_.y_tensor,
|
||||
s_.y_data));
|
||||
if (has_b) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.b_tensor, Base::s_.b_data,
|
||||
&alpha, Base::s_.y_tensor, Base::s_.y_data));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.b_tensor, s_.b_data,
|
||||
&alpha, s_.y_tensor, s_.y_data));
|
||||
}
|
||||
if (has_z) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.z_tensor, Base::s_.z_data,
|
||||
&alpha, Base::s_.y_tensor, Base::s_.y_data));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.z_tensor, s_.z_data,
|
||||
&alpha, s_.y_tensor, s_.y_data));
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, Base::s_.y_tensor,
|
||||
Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, s_.y_tensor,
|
||||
s_.y_data, &beta, s_.y_tensor, s_.y_data));
|
||||
}
|
||||
if (Base::s_.post_slicing_required) {
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
|
||||
this->Stream(context), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
|
||||
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
|
||||
if (s_.post_slicing_required) {
|
||||
ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(
|
||||
this->Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, s_.Y->MutableDataRaw(),
|
||||
s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -107,6 +421,25 @@ class FusedConv : public onnxruntime::cuda::Conv<T, false> {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline IAllocatorUniquePtr<void> GetWorkSpace(onnxruntime::Stream* stream) const {
|
||||
return GetScratchBuffer<void>(s_.workspace_bytes, stream);
|
||||
}
|
||||
|
||||
ConvAttributes conv_attrs_;
|
||||
mutable ::onnxruntime::cuda::CudnnConvState<cudnnConvolutionFwdAlgoPerf_t> s_;
|
||||
constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
constexpr static cudnnConvolutionFwdAlgo_t kAllAlgos[] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
};
|
||||
|
||||
cudnnActivationMode_t activation_mode_;
|
||||
cudnnActivationDescriptor_t activation_desc_ = nullptr;
|
||||
};
|
||||
|
@ -122,4 +455,4 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
|
@ -58,16 +58,11 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
|
|||
}
|
||||
|
||||
bool ConvFusionDataTypeCheck(const Node& conv_node) {
|
||||
// TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation
|
||||
// TODO(hasesh): The CPU EP only supports float type for the Conv+Activation
|
||||
// and the Conv+Add+Relu fusions.
|
||||
// Assess the support level for the other compatible EPs and if they also
|
||||
// only support float, remove the EP check altogether.
|
||||
const std::string_view node_ep = conv_node.GetExecutionProviderType();
|
||||
if (node_ep == kCudaExecutionProvider) {
|
||||
if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (node_ep == kCpuExecutionProvider) {
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) &&
|
||||
|
@ -120,7 +115,9 @@ class ConvActivationSelector : public NodeSelector {
|
|||
}
|
||||
|
||||
// check EP type and activation
|
||||
if (node_ep == kCudaExecutionProvider || node_ep == kRocmExecutionProvider) {
|
||||
if (node_ep == kCudaExecutionProvider) {
|
||||
return std::nullopt;
|
||||
} else if (node_ep == kRocmExecutionProvider) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -142,43 +139,6 @@ class ConvActivationSelector : public NodeSelector {
|
|||
}
|
||||
};
|
||||
|
||||
class ConvAddRelu : public NodeSelector {
|
||||
public:
|
||||
ConvAddRelu() = default;
|
||||
|
||||
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override {
|
||||
const std::string_view node_ep = node.GetExecutionProviderType();
|
||||
// only for CUDA EP
|
||||
if (node_ep != kCudaExecutionProvider) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ConvFusionDataTypeCheck(node)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
|
||||
if (!add_node ||
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) ||
|
||||
add_node->GetExecutionProviderType() != node_ep) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto* relu_node = GetLoneConsumerNode(graph_viewer, *add_node);
|
||||
if (!relu_node ||
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(*relu_node, "Relu", {6, 13, 14}) ||
|
||||
relu_node->GetExecutionProviderType() != node_ep) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
NodesToOptimizeIndicesBuilder builder{};
|
||||
builder.target_node = node.Index();
|
||||
builder.output_nodes = {add_node->Index(),
|
||||
relu_node->Index()};
|
||||
return builder.Build();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace selectors
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
|
@ -304,22 +264,9 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void RegisterConvAddReluFusionRules(SelectorActionRegistry& registry) {
|
||||
const auto name = "ConvAddRelu";
|
||||
auto action = std::make_unique<actions::FuseConvAddRelu>();
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
auto selector = std::make_unique<selectors::ConvAddRelu>();
|
||||
registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}},
|
||||
std::move(selector), std::move(action));
|
||||
#else
|
||||
registry.RegisterAction(name, std::move(action));
|
||||
#endif
|
||||
}
|
||||
|
||||
SelectorActionRegistry CreateSelectorActionRegistry() {
|
||||
SelectorActionRegistry registry{};
|
||||
RegisterConvActivationFusionRules(registry);
|
||||
RegisterConvAddReluFusionRules(registry);
|
||||
return registry;
|
||||
}
|
||||
|
||||
|
|
|
@ -282,6 +282,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
onnxruntime::kCudaExecutionProvider,
|
||||
onnxruntime::kRocmExecutionProvider,
|
||||
onnxruntime::kDmlExecutionProvider};
|
||||
const InlinedHashSet<std::string_view> cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
|
||||
onnxruntime::kRocmExecutionProvider,
|
||||
onnxruntime::kAclExecutionProvider,
|
||||
onnxruntime::kArmNNExecutionProvider,
|
||||
onnxruntime::kJsExecutionProvider};
|
||||
const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
|
||||
onnxruntime::kCudaExecutionProvider,
|
||||
onnxruntime::kRocmExecutionProvider,
|
||||
|
@ -318,7 +323,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps));
|
||||
transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_ep));
|
||||
|
||||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_js_eps));
|
||||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
|
|
@ -83,6 +83,14 @@ inline AutoPadType StringToAutoPadType(const std::string& str) {
|
|||
|
||||
// helper function
|
||||
|
||||
constexpr inline int64_t ComputeOutputShape(const int64_t in_dim,
|
||||
const int64_t stride, const int64_t kernel, const int64_t dilation,
|
||||
const int64_t pad_head, const int64_t pad_tail) {
|
||||
const SafeInt<int64_t> dkernel = SafeInt<int64_t>(dilation) * (kernel - 1) + 1;
|
||||
int64_t dkernel_value = SafeInt<int64_t>(in_dim) + pad_head + pad_tail - dkernel;
|
||||
return static_cast<int64_t>(static_cast<double>(dkernel_value) / stride + 1);
|
||||
}
|
||||
|
||||
inline Status ComputePad(const int64_t in_dim,
|
||||
const int64_t stride, const int64_t kernel, const int64_t dilation,
|
||||
AutoPadType pad_type,
|
||||
|
@ -106,6 +114,15 @@ inline Status ComputePad(const int64_t in_dim,
|
|||
// is retained as is
|
||||
SafeInt<int64_t> legacy_target_size = (SafeInt<int64_t>(in_dim) + stride - 1) / stride;
|
||||
SafeInt<int64_t> pad_needed = (legacy_target_size - 1) * stride + kernel - in_dim;
|
||||
// out_dim = floor((in_dim + 2p - k) / s) + 1
|
||||
// => if (in_dim + 2p - k) is not divisible by s we can remove the floor with following equation:
|
||||
// out_dim + eps = ((in_dim + 2p - k) / s) + 1 ;where eps is in [0.0, 1.0]
|
||||
// therefore in edge cases padding can lower calculated above than it should be
|
||||
SafeInt<int64_t> actual_out_size = ComputeOutputShape(in_dim, stride, kernel, /*dilation*/ 1,
|
||||
pad_needed, pad_needed);
|
||||
if (actual_out_size < legacy_target_size) {
|
||||
pad_needed += 1;
|
||||
}
|
||||
// make sure padding is symmetric
|
||||
if (force_symmetric_auto_padding) {
|
||||
// Inlining math::roundUpPow2() from util/math.h to avoid bringing in the transitive dependencies.
|
||||
|
@ -126,14 +143,6 @@ inline Status ComputePad(const int64_t in_dim,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
constexpr inline int64_t ComputeOutputShape(const int64_t in_dim,
|
||||
const int64_t stride, const int64_t kernel, const int64_t dilation,
|
||||
const int64_t pad_head, const int64_t pad_tail) {
|
||||
const SafeInt<int64_t> dkernel = SafeInt<int64_t>(dilation) * (kernel - 1) + 1;
|
||||
int64_t dkernel_value = SafeInt<int64_t>(in_dim) + pad_head + pad_tail - dkernel;
|
||||
return static_cast<int64_t>(static_cast<double>(dkernel_value) / stride + 1);
|
||||
}
|
||||
|
||||
inline Status ComputePadAndOutputShape(const int64_t in_dim,
|
||||
const int64_t stride, const int64_t kernel, const int64_t dilation,
|
||||
AutoPadType pad_type,
|
||||
|
|
|
@ -34,7 +34,6 @@ const char* CudaErrString<cudaError_t>(cudaError_t x) {
|
|||
template <>
|
||||
const char* CudaErrString<cublasStatus_t>(cublasStatus_t e) {
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
switch (e) {
|
||||
CASE_ENUM_TO_STR(CUBLAS_STATUS_SUCCESS);
|
||||
CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED);
|
||||
|
@ -87,9 +86,15 @@ const char* CudaErrString<ncclResult_t>(ncclResult_t e) {
|
|||
}
|
||||
#endif
|
||||
|
||||
template <typename ERRTYPE, bool THRW>
|
||||
template <typename ERRTYPE>
|
||||
int GetErrorCode(ERRTYPE err) {
|
||||
return static_cast<int>(err);
|
||||
}
|
||||
|
||||
template <typename ERRTYPE, bool THRW, typename SUCCTYPE>
|
||||
std::conditional_t<THRW, void, Status> CudaCall(
|
||||
ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) {
|
||||
ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg,
|
||||
const char* file, const int line) {
|
||||
if (retCode != successCode) {
|
||||
try {
|
||||
#ifdef _WIN32
|
||||
|
@ -108,7 +113,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
|
|||
cudaGetLastError(); // clear last CUDA error
|
||||
static char str[1024];
|
||||
snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s",
|
||||
libName, (int)retCode, CudaErrString(retCode), currentCudaDevice,
|
||||
libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice,
|
||||
hostname,
|
||||
file, line, exprString, msg);
|
||||
if constexpr (THRW) {
|
||||
|
@ -118,7 +123,8 @@ std::conditional_t<THRW, void, Status> CudaCall(
|
|||
LOGS_DEFAULT(ERROR) << str;
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str);
|
||||
}
|
||||
} catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error
|
||||
} catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction,
|
||||
// so we'd never get to see the error
|
||||
if constexpr (THRW) {
|
||||
ORT_THROW(e.what());
|
||||
} else {
|
||||
|
|
|
@ -180,6 +180,7 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de
|
|||
|
||||
CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
|
||||
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
|
||||
LOGS_DEFAULT(INFO) << "cuDNN version: " << cudnnGetVersion();
|
||||
#endif
|
||||
cuda_graph_.SetStream(stream);
|
||||
}
|
||||
|
@ -2469,6 +2470,19 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const
|
|||
return false;
|
||||
}
|
||||
|
||||
static bool NhwcConvNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger,
|
||||
[[maybe_unused]] const GraphViewer& graph_viewer,
|
||||
[[maybe_unused]] const bool prefer_nhwc) {
|
||||
// NHWC implementation doesn't handle W in NHWC layout if it's not an initializer
|
||||
if (!graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), true)) {
|
||||
LOGS(logger, WARNING) << "Dropping the NhwcConv node: " << node.Name()
|
||||
<< " to CPU because the Cuda EP requires W as initializer for NHWC operation.";
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
|
||||
const auto& node_attributes = node.GetAttributes();
|
||||
// Check attributes
|
||||
|
@ -2539,6 +2553,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
} else if ("Cast" == node.OpType()) {
|
||||
not_supported = CastNeedFallbackToCPU(node);
|
||||
// cast is not compute heavy, and may be placed outside
|
||||
} else if ("NhwcConv" == node.OpType()) {
|
||||
not_supported = NhwcConvNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
|
||||
force_inside = !not_supported;
|
||||
}
|
||||
|
||||
if (!force_inside && not_supported) {
|
||||
|
|
|
@ -82,6 +82,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; }
|
||||
bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; }
|
||||
bool IsNHWCPreferred() const { return info_.prefer_nhwc; }
|
||||
bool IsFuseConvBias() const { return info_.fuse_conv_bias; }
|
||||
bool UseTF32() const { return info_.use_tf32; }
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
|
|
@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s
|
|||
constexpr const char* kPreferNHWCMode = "prefer_nhwc";
|
||||
constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream";
|
||||
constexpr const char* kUseTF32 = "use_tf32";
|
||||
constexpr const char* kFuseConvBias = "fuse_conv_bias";
|
||||
constexpr const char* kSdpaKernel = "sdpa_kernel";
|
||||
|
||||
} // namespace provider_option_names
|
||||
|
@ -119,6 +120,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
.AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kFuseConvBias, info.fuse_conv_bias)
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kTunableOpEnable,
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
|
@ -173,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
|
|||
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
|
||||
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
|
||||
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
|
||||
{cuda::provider_option_names::kFuseConvBias, MakeStringWithClassicLocale(info.fuse_conv_bias)},
|
||||
};
|
||||
|
||||
return options;
|
||||
|
@ -195,6 +198,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
|
|||
{cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)},
|
||||
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
|
||||
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
|
||||
{cuda::provider_option_names::kFuseConvBias, MakeStringWithClassicLocale(info.fuse_conv_bias)},
|
||||
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
|
||||
};
|
||||
|
||||
|
|
|
@ -78,6 +78,7 @@ struct CUDAExecutionProviderInfo {
|
|||
|
||||
// By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices.
|
||||
bool use_tf32{true};
|
||||
bool fuse_conv_bias{false};
|
||||
|
||||
int sdpa_kernel{0};
|
||||
|
||||
|
@ -107,7 +108,8 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> {
|
|||
(static_cast<size_t>(info.enable_skip_layer_norm_strict_mode) << 27) ^
|
||||
(static_cast<size_t>(info.prefer_nhwc) << 28) ^
|
||||
(static_cast<size_t>(info.use_ep_level_unified_stream) << 29) ^
|
||||
(static_cast<size_t>(info.use_tf32) << 30);
|
||||
(static_cast<size_t>(info.use_tf32) << 30) ^
|
||||
(static_cast<size_t>(info.fuse_conv_bias) << 31);
|
||||
onnxruntime::HashCombine(data, value);
|
||||
|
||||
onnxruntime::HashCombine(info.gpu_mem_limit, value);
|
||||
|
|
|
@@ -219,6 +219,7 @@ struct CUDA_Provider : Provider {
    info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0;
    info.enable_cuda_graph = params->enable_cuda_graph != 0;
    info.prefer_nhwc = params->prefer_nhwc;
    info.fuse_conv_bias = params->fuse_conv_bias;
    info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0;
    info.tunable_op.enable = params->tunable_op_enable;
    info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;

@@ -262,6 +263,7 @@ struct CUDA_Provider : Provider {
    cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream;
    cuda_options.use_tf32 = internal_options.use_tf32;
    cuda_options.sdpa_kernel = internal_options.sdpa_kernel;
    cuda_options.fuse_conv_bias = internal_options.fuse_conv_bias;
  }

  ProviderOptions GetProviderOptions(const void* provider_options) override {
@@ -215,6 +215,9 @@ void* CudaStream::GetResource(int version, int id) const {
    case CudaResource::prefer_nhwc_t:
      return reinterpret_cast<void*>(ep_info_.prefer_nhwc);
      break;
    case CudaResource::fuse_conv_bias_t:
      return reinterpret_cast<void*>(ep_info_.fuse_conv_bias);
      break;
    case CudaResource::use_tf32_t:
      return reinterpret_cast<void*>(ep_info_.use_tf32);
      break;
@@ -3,6 +3,8 @@
// Licensed under the MIT License.

#include <utility>
#include <string>
#include <vector>

#include "core/providers/cuda/cudnn_common.h"
#include "core/common/inlined_containers.h"

@@ -60,7 +62,25 @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dat
    dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
    strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
  }
  CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
  CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank),
                                                   dims.data(), strides.data()));
  return Status::OK();
}

Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dataType,
                        gsl::span<const int64_t> input_strides) {
  ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());

  int rank = gsl::narrow_cast<int>(input_dims.size());
  InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
  InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);

  for (int i = 0; i < rank; i++) {
    dims[i] = gsl::narrow_cast<int>(input_dims[i]);
    strides[i] = gsl::narrow_cast<int>(input_strides[i]);
  }
  CUDNN_RETURN_IF_ERROR(
      cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
  return Status::OK();
}

@@ -100,7 +120,8 @@ Status CudnnDataTensor::Set(cudnnDataType_t dataType,
                            const int32_t* seq_lengths) {
  ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());

  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED,
  // so that it will auto fill 0 for the shorter sequences
  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
  float padding_fill = 0.0f;
  CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout,
@@ -238,6 +259,91 @@ const Float8E5M2 Consts<Float8E5M2>::One = Float8E5M2(1.0f, true);

#endif

std::vector<int64_t> generateStrides(const std::vector<int64_t>& shape, bool channels_last) {
  // For INT8x4 and INT8x32 we still compute standard strides here to input
  // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
  std::vector<int64_t> strides(shape.size());
  int64_t nbDims = strides.size();
  if (nbDims <= 1) {
    strides[0] = 1;
    return strides;
  }
  if (channels_last) {
    // Here we assume that the format is CUDNN_TENSOR_NHWC
    strides[1] = 1;
    strides[nbDims - 1] = strides[1] * shape[1];
    for (int64_t d = nbDims - 2; d >= 2; d--) {
      strides[d] = strides[d + 1] * shape[d + 1];
    }
    strides[0] = strides[2] * shape[2];
  } else {
    strides[nbDims - 1] = 1;
    for (int64_t d = nbDims - 2; d >= 0; d--) {
      strides[d] = strides[d + 1] * shape[d + 1];
    }
  }
  return strides;
}

#if !defined(__CUDACC__)
CudnnFeTensor::CudnnFeTensor(const onnxruntime::TensorShapeVector& shape,
                             const std::string& name,
                             std::optional<cudnn_frontend::DataType_t> dtype,
                             const bool nhwc) {
  std::vector<int64_t> shape_vec;
  if (shape.size() == 1) {
    shape_vec = {1, shape[0], 1, 1};
  } else if (shape.size() >= 4) {
    for (size_t i = 0; i < shape.size(); i++) {
      shape_vec.push_back(shape[i]);
    }
  } else {
    ORT_THROW("Invalid tensor shape size, tensor name: ", name, ", shape size: ", shape.size());
  }
  auto strides = generateStrides(shape_vec, nhwc);

  if (dtype.has_value()) {
    tensor_ = cudnn_frontend::graph::Tensor_attributes()
                  .set_name(name)
                  .set_dim(shape_vec)
                  .set_stride(strides)
                  .set_data_type(dtype.value());
  } else {
    tensor_ = cudnn_frontend::graph::Tensor_attributes().set_name(name).set_dim(shape_vec).set_stride(strides);
  }
}

template <typename T>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() {
  return cudnn_frontend::DataType_t::NOT_SET;
}

template <>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType<float>() {
  return cudnn_frontend::DataType_t::FLOAT;
}

template <>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType<half>() {
  return cudnn_frontend::DataType_t::HALF;
}

template <>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType<double>() {
  return cudnn_frontend::DataType_t::DOUBLE;
}

template <>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType<int8_t>() {
  return cudnn_frontend::DataType_t::INT8;
}

template <>
cudnn_frontend::DataType_t CudnnFeTensor::GetDataType<uint8_t>() {
  return cudnn_frontend::DataType_t::UINT8;
}
#endif

}  // namespace cuda
}  // namespace onnxruntime
#endif
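To make the layout rule concrete, here is a small standalone sketch (not part of this commit) that mirrors generateStrides() for dims given in NCHW order; the values in the comments are hand-computed from the loop above, and the sample shape is arbitrary.

// Standalone sketch: channels_last=true yields NHWC memory strides over an NCHW-ordered
// dim vector, channels_last=false yields contiguous NCHW strides.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> Strides(const std::vector<int64_t>& shape, bool channels_last) {
  std::vector<int64_t> s(shape.size());
  int64_t n = static_cast<int64_t>(s.size());
  if (n <= 1) { s[0] = 1; return s; }
  if (channels_last) {
    s[1] = 1;                    // C is innermost
    s[n - 1] = s[1] * shape[1];  // W steps over C elements
    for (int64_t d = n - 2; d >= 2; d--) s[d] = s[d + 1] * shape[d + 1];
    s[0] = s[2] * shape[2];      // N steps over H*W*C elements
  } else {
    s[n - 1] = 1;
    for (int64_t d = n - 2; d >= 0; d--) s[d] = s[d + 1] * shape[d + 1];
  }
  return s;
}

int main() {
  std::vector<int64_t> shape{1, 8, 4, 4};               // N, C, H, W
  auto nchw = Strides(shape, /*channels_last=*/false);  // {128, 16, 4, 1}
  auto nhwc = Strides(shape, /*channels_last=*/true);   // {128, 1, 32, 8}
  std::printf("nchw: %lld %lld %lld %lld\n",
              (long long)nchw[0], (long long)nchw[1], (long long)nchw[2], (long long)nchw[3]);
  std::printf("nhwc: %lld %lld %lld %lld\n",
              (long long)nhwc[0], (long long)nhwc[1], (long long)nhwc[2], (long long)nhwc[3]);
  return 0;
}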
@@ -5,12 +5,22 @@
#pragma once

#include <cfloat>
#include <vector>
#include <string>

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"

#ifndef USE_CUDA_MINIMAL
#if !defined(__CUDACC__)
#include <cudnn_frontend.h>
#endif

namespace onnxruntime {
namespace cuda {

#define CUDNN_FE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_FE_CALL(expr))

class CudnnTensor final {
 public:
  CudnnTensor();

@@ -19,6 +29,7 @@ class CudnnTensor final {

  Status Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dataType, bool is_nhwc = false);
  Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode);
  Status Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dataType, gsl::span<const int64_t> input_strides);
  // Set 4D tensor format (for NHWC)
  Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w);

@@ -139,7 +150,8 @@ struct Consts<BFloat16> {
inline double ClampCudnnBatchNormEpsilon(double epsilon) {
  if (epsilon < CUDNN_BN_MIN_EPSILON) {
    if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON)
      LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. Setting it to CUDNN_BN_MIN_EPSILON";
      LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. "
                            << "Setting it to CUDNN_BN_MIN_EPSILON";
    return CUDNN_BN_MIN_EPSILON;
  }
  return epsilon;

@@ -258,6 +270,23 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc,
  return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA);
}

std::vector<int64_t> generateStrides(const std::vector<int64_t>& shape, bool channels_last);

#if !defined(__CUDACC__)
class CudnnFeTensor final {
 public:
  CudnnFeTensor(const onnxruntime::TensorShapeVector& shape, const std::string& name,
                std::optional<cudnn_frontend::DataType_t> dtype, const bool nhwc);

  template <typename T>
  static cudnn_frontend::DataType_t GetDataType();
  cudnn_frontend::graph::Tensor_attributes Get() { return tensor_; }

 private:
  cudnn_frontend::graph::Tensor_attributes tensor_;
};
#endif

}  // namespace cuda
}  // namespace onnxruntime
#endif
@@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
#include "core/providers/shared_library/provider_api.h"
#include <core/platform/env.h>
#if !defined(__CUDACC__)
#include <cudnn_frontend.h>
#endif
#ifdef _WIN32
#else  // POSIX
#include <unistd.h>
#include <string.h>
#endif

namespace onnxruntime {

using namespace common;

template <typename ERRTYPE>
const char* CudaErrString(ERRTYPE) {
  ORT_NOT_IMPLEMENTED();
}

#if !defined(__CUDACC__)
#define CASE_ENUM_TO_STR_CUDNN_FE(x)    \
  case cudnn_frontend::error_code_t::x: \
    return #x

template <>
const char* CudaErrString<cudnn_frontend::error_t>(cudnn_frontend::error_t x) {
  cudaDeviceSynchronize();
  LOGS_DEFAULT(ERROR) << x.get_message();
  switch (x.get_code()) {
    CASE_ENUM_TO_STR_CUDNN_FE(OK);
    CASE_ENUM_TO_STR_CUDNN_FE(ATTRIBUTE_NOT_SET);
    CASE_ENUM_TO_STR_CUDNN_FE(SHAPE_DEDUCTION_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(INVALID_TENSOR_NAME);
    CASE_ENUM_TO_STR_CUDNN_FE(INVALID_VARIANT_PACK);
    CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_NOT_SUPPORTED);
    CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_PLAN_CREATION_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(HEURISTIC_QUERY_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(UNSUPPORTED_GRAPH_FORMAT);
    CASE_ENUM_TO_STR_CUDNN_FE(CUDA_API_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(CUDNN_BACKEND_API_FAILED);
    CASE_ENUM_TO_STR_CUDNN_FE(INVALID_CUDA_DEVICE);
    CASE_ENUM_TO_STR_CUDNN_FE(HANDLE_ERROR);
    default:
      return "Unknown CUDNN_FRONTEND error status";
  }
}

template <typename ERRTYPE>
int GetErrorCode(ERRTYPE err) {
  return static_cast<int>(err);
}

template <>
int GetErrorCode(cudnn_frontend::error_t err) {
  return static_cast<int>(err.get_code());
}

template <typename ERRTYPE, bool THRW, typename SUCCTYPE>
std::conditional_t<THRW, void, Status> CudaCall(
    ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg,
    const char* file, const int line) {
  if (retCode != successCode) {
    try {
#ifdef _WIN32
      std::string hostname_str = GetEnvironmentVar("COMPUTERNAME");
      if (hostname_str.empty()) {
        hostname_str = "?";
      }
      const char* hostname = hostname_str.c_str();
#else
      char hostname[HOST_NAME_MAX];
      if (gethostname(hostname, HOST_NAME_MAX) != 0)
        strcpy(hostname, "?");
#endif
      int currentCudaDevice;
      cudaGetDevice(&currentCudaDevice);
      cudaGetLastError();  // clear last CUDA error
      static char str[1024];
      snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s",
               libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice,
               hostname,
               file, line, exprString, msg);
      if constexpr (THRW) {
        // throw an exception with the error info
        ORT_THROW(str);
      } else {
        LOGS_DEFAULT(ERROR) << str;
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str);
      }
    } catch (const std::exception& e) {  // catch, log, and rethrow since CUDA code sometimes hangs in destruction,
                                         // so we'd never get to see the error
      if constexpr (THRW) {
        ORT_THROW(e.what());
      } else {
        LOGS_DEFAULT(ERROR) << e.what();
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
      }
    }
  }
  if constexpr (!THRW) {
    return Status::OK();
  }
}

template Status CudaCall<cudnn_frontend::error_t, false, cudnn_frontend::error_code_t>(
    cudnn_frontend::error_t retCode, const char* exprString, const char* libName,
    cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line);
template void CudaCall<cudnn_frontend::error_t, true, cudnn_frontend::error_code_t>(
    cudnn_frontend::error_t retCode, const char* exprString, const char* libName,
    cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line);

#endif
}  // namespace onnxruntime
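The Conv changes below drive the cudnn_frontend graph API (tensor, conv_fprop, a pointwise ADD for the bias, then validate/build/execute). As a reference, here is a minimal standalone sketch of that flow; it is not part of the commit, it assumes cuDNN 9 with the cudnn-frontend v1.x header, and the shapes, strides, and unchecked CUDA calls are illustrative only.

// Standalone sketch: build and run a conv + bias-add graph with cudnn_frontend.
#include <cudnn_frontend.h>
#include <cuda_runtime.h>
#include <memory>
#include <unordered_map>
#include <vector>

namespace fe = cudnn_frontend;

int main() {
  cudnnHandle_t handle;
  cudnnCreate(&handle);

  fe::graph::Graph graph;
  graph.set_io_data_type(fe::DataType_t::FLOAT)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  // Dims are given in NCHW order with NHWC strides (see generateStrides above).
  auto X = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("x").set_dim({1, 8, 16, 16}).set_stride({2048, 1, 128, 8}));
  auto W = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("w").set_dim({4, 8, 3, 3}).set_stride({72, 1, 24, 8}));
  auto B = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("b").set_dim({1, 4, 1, 1}).set_stride({4, 1, 4, 4}));

  auto conv_options = fe::graph::Conv_fprop_attributes()
                          .set_pre_padding({1, 1}).set_post_padding({1, 1})
                          .set_stride({1, 1}).set_dilation({1, 1});
  auto conv_out = graph.conv_fprop(X, W, conv_options);

  auto add_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD);
  auto Y = graph.pointwise(conv_out, B, add_options);
  Y->set_dim({1, 4, 16, 16});
  Y->set_stride({1024, 1, 64, 4});
  Y->set_output(true);

  if (!graph.validate().is_good() ||
      !graph.build_operation_graph(handle).is_good() ||
      !graph.create_execution_plans({fe::HeurMode_t::A}).is_good() ||
      !graph.check_support(handle).is_good() ||
      !graph.build_plans(handle).is_good()) {
    return 1;  // a real kernel surfaces error_t::get_message() instead
  }

  void *x_data, *w_data, *b_data, *y_data, *workspace;
  cudaMalloc(&x_data, 1 * 8 * 16 * 16 * sizeof(float));
  cudaMalloc(&w_data, 4 * 8 * 3 * 3 * sizeof(float));
  cudaMalloc(&b_data, 4 * sizeof(float));
  cudaMalloc(&y_data, 1 * 4 * 16 * 16 * sizeof(float));
  cudaMalloc(&workspace, graph.get_workspace_size());

  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
      {X, x_data}, {W, w_data}, {B, b_data}, {Y, y_data}};
  graph.execute(handle, variant_pack, workspace);

  cudnnDestroy(handle);
  return 0;
}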
@ -3,14 +3,20 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
#include "core/common/span_utils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/tensor/slice.h"
|
||||
#include "core/providers/cuda/tensor/transpose.h"
|
||||
#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
|
||||
|
||||
#if CUDNN_MAJOR < 9
|
||||
// if compiled with cuDNN 8 we want to use the legacy cuDNN API
|
||||
#include "conv_8.h"
|
||||
#endif
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
|
@ -43,58 +49,7 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true)
|
|||
REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true)
|
||||
#endif
|
||||
|
||||
template <typename T, bool NHWC>
|
||||
const cudnnConvolutionFwdAlgo_t Conv<T, NHWC>::kAllAlgos[] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
};
|
||||
|
||||
cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) {
|
||||
return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz);
|
||||
}
|
||||
|
||||
size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s,
|
||||
const cudnnConvolutionFwdAlgo_t* algo, int n_algo) {
|
||||
// TODO: get maximum available size from memory arena
|
||||
size_t free, total;
|
||||
CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
|
||||
// Assuming 10% of fragmentation
|
||||
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
cudnnStatus_t err;
|
||||
size_t sz;
|
||||
err = GetWorkspaceSize(handle, s, algo[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue;
|
||||
max_ws_size = sz;
|
||||
}
|
||||
return max_ws_size;
|
||||
}
|
||||
|
||||
Status SliceOutUnwantedOutputSection(cudaStream_t stream,
|
||||
const void* input_data, gsl::span<const int64_t> input_dims,
|
||||
void* output_data,
|
||||
const gsl::span<const int64_t>& output_dims,
|
||||
const gsl::span<const int64_t>& starts,
|
||||
const gsl::span<const int64_t>& ends,
|
||||
const gsl::span<const int64_t>& axes,
|
||||
size_t element_size) {
|
||||
SliceOp::PrepareForComputeMetadata compute_metadata(input_dims);
|
||||
|
||||
ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata));
|
||||
|
||||
// As a sanity check, ensure that the slice operator's output shape matches with the expected output shape
|
||||
ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims));
|
||||
|
||||
return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size);
|
||||
}
|
||||
|
||||
// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW
|
||||
template <typename T, bool NHWC>
|
||||
Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
|
||||
|
@ -104,14 +59,20 @@ Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
|
|||
if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W
|
||||
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
|
||||
auto orig_shape = tensor.Shape();
|
||||
auto shape_size = orig_shape.GetDims().size();
|
||||
|
||||
InlinedVector<size_t> perm{0, 2, 3, 1};
|
||||
gsl::span<size_t> permutation(perm.data(), 4);
|
||||
TensorShapeVector new_dims{orig_shape[0],
|
||||
orig_shape[2],
|
||||
orig_shape[3],
|
||||
orig_shape[1]};
|
||||
W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
|
||||
InlinedVector<size_t, 5> perm;
|
||||
perm.push_back(0);
|
||||
for (size_t i = 2; i < shape_size; i++) perm.push_back(i);
|
||||
perm.push_back(1);
|
||||
gsl::span<size_t> permutation(perm.data(), shape_size);
|
||||
|
||||
TensorShapeVector nhwc_dims;
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
nhwc_dims.push_back(orig_shape[perm[i]]);
|
||||
}
|
||||
|
||||
W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc));
|
||||
|
||||
auto status = cuda::Transpose::DoTranspose(GetDeviceProp(),
|
||||
DefaultCudaStream(),
|
||||
|
@ -122,6 +83,8 @@ Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
|
|||
}
|
||||
CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream()));
|
||||
is_packed = true;
|
||||
} else {
|
||||
W_already_nhwc = true;
|
||||
}
|
||||
} else {
|
||||
ORT_UNUSED_PARAMETER(tensor);
|
||||
|
@ -132,45 +95,205 @@ Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, bool NHWC>
|
||||
Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected) const {
|
||||
#if CUDNN_MAJOR >= 9
|
||||
#if !defined(__CUDACC__)
|
||||
|
||||
template <typename T, bool Layout>
|
||||
Status Conv<T, Layout>::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims,
|
||||
const onnxruntime::TensorShapeVector& w_dims,
|
||||
const Tensor* B,
|
||||
const Tensor* Z,
|
||||
const TensorShapeVector& y_dims,
|
||||
cudnnContext* handle,
|
||||
const cudnn_frontend::HeurMode_t heur_mode,
|
||||
const std::vector<int64_t>& pads,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
const bool bias_expected,
|
||||
const bool fuse_bias,
|
||||
const bool fuse_act,
|
||||
const bool w_in_nhwc,
|
||||
const bool use_tf32) const {
|
||||
s_.bias_fused = fuse_bias;
|
||||
s_.act_fused = fuse_act;
|
||||
s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change
|
||||
s_.cudnn_fe_graph = std::make_unique<cudnn_frontend::graph::Graph>();
|
||||
cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType<CudaT>();
|
||||
s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type);
|
||||
if (data_type == cudnn_frontend::DataType_t::HALF) {
|
||||
s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
|
||||
} else {
|
||||
s_.cudnn_fe_graph->set_compute_data_type(data_type);
|
||||
}
|
||||
|
||||
s_.cudnn_fe_X = s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get());
|
||||
s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get());
|
||||
|
||||
auto conv_options = cudnn_frontend::graph::Conv_fprop_attributes()
|
||||
.set_pre_padding(std::vector<int64_t>(pads.begin(),
|
||||
pads.begin() + pads.size() / 2))
|
||||
.set_post_padding(std::vector<int64_t>(pads.begin() + pads.size() / 2, pads.end()))
|
||||
.set_stride(strides)
|
||||
.set_dilation(dilations);
|
||||
s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_fprop(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options);
|
||||
auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get();
|
||||
|
||||
if (!bias_expected && B == nullptr) {
|
||||
s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y;
|
||||
} else {
|
||||
int64_t bias_size;
|
||||
if (B != nullptr) {
|
||||
bias_size = B->Shape()[0];
|
||||
} else {
|
||||
bias_size = w_dims[0];
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_z_tensor;
|
||||
if (Z) {
|
||||
const auto& z_shape = Z->Shape().AsShapeVector();
|
||||
cudnn_fe_z_tensor = CudnnFeTensor(z_shape, "z", data_type, Layout == LAYOUT_NHWC).Get();
|
||||
} else if (fuse_bias && Layout == LAYOUT_NCHW) {
|
||||
// Z is required for NCHW precompiled kernels in cuDNN
|
||||
s_.z_data = s_.y_data;
|
||||
cudnn_fe_z_tensor = cudnn_fe_y_tensor;
|
||||
}
|
||||
|
||||
if (fuse_bias) {
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> add_output;
|
||||
if (cudnn_fe_z_tensor.has_value()) {
|
||||
s_.cudnn_fe_Z = s_.cudnn_fe_graph->tensor(cudnn_fe_z_tensor.value());
|
||||
auto add_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD);
|
||||
add_output = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_Z, add_options);
|
||||
} else {
|
||||
add_output = s_.cudnn_fe_conv_Y;
|
||||
}
|
||||
|
||||
onnxruntime::TensorShapeVector b_dims;
|
||||
for (size_t i = 0; i < x_dims.size(); i++) {
|
||||
b_dims.push_back(i == 1 ? bias_size : 1);
|
||||
}
|
||||
auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get();
|
||||
auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD);
|
||||
s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor);
|
||||
s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(add_output, s_.cudnn_fe_B, bias_options);
|
||||
} else {
|
||||
s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y;
|
||||
|
||||
TensorShapeVector b_dims(y_dims.size(), 1);
|
||||
TensorShapeVector b_strides(y_dims.size(), 1);
|
||||
b_dims[1] = bias_size;
|
||||
b_strides[0] = bias_size;
|
||||
if (Z) {
|
||||
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().AsShapeVector(),
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
cudnn_fe_z_tensor->get_stride()));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>(), b_strides));
|
||||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType<CudaT>(), cudnn_fe_y_tensor.get_stride()));
|
||||
|
||||
/* Creating an own CUDNN Frontend graph for the bias addition.
|
||||
s_.cudnn_fe_bias_graph = std::make_unique<cudnn_frontend::graph::Graph>();
|
||||
s_.cudnn_fe_bias_graph->set_io_data_type(data_type)
|
||||
.set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ?
|
||||
cudnn_frontend::DataType_t::FLOAT : data_type)
|
||||
.set_intermediate_data_type(data_type);
|
||||
s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor<NHWC>(y_dims, "x", data_type).Get());
|
||||
|
||||
s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor);
|
||||
s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options);
|
||||
s_.cudnn_fe_bias_Y->set_output(true);
|
||||
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate());
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle));
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode}));
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle));
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/
|
||||
}
|
||||
}
|
||||
if (fuse_act && s_.cudnn_fe_act_attr.has_value()) {
|
||||
auto& activation_attr = s_.cudnn_fe_act_attr.value();
|
||||
s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr);
|
||||
}
|
||||
|
||||
s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim());
|
||||
s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride());
|
||||
s_.cudnn_fe_Y->set_output(true);
|
||||
|
||||
try {
|
||||
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate());
|
||||
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle));
|
||||
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode}));
|
||||
} catch (const std::exception& ex) {
|
||||
std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(),
|
||||
"with the cudnn frontend json:\n", s_.cudnn_fe_graph->print());
|
||||
return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message);
|
||||
}
|
||||
|
||||
if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE});
|
||||
|
||||
try {
|
||||
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle));
|
||||
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle));
|
||||
} catch (const std::exception& ex) {
|
||||
if (!fuse_bias && !fuse_act && use_tf32) {
|
||||
std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(),
|
||||
"with the cudnn frontend json:\n", s_.cudnn_fe_graph->print());
|
||||
return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message);
|
||||
}
|
||||
|
||||
// Try fallback.
|
||||
return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, Z, y_dims, handle, heur_mode,
|
||||
pads, strides, dilations, bias_expected, false, false, w_in_nhwc, true);
|
||||
}
|
||||
|
||||
s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T, bool Layout>
|
||||
Status Conv<T, Layout>::UpdateState(OpKernelContext* context, bool bias_expected) const {
|
||||
constexpr bool channels_last = Layout;
|
||||
|
||||
// set X
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& x_shape = X->Shape();
|
||||
// X incl. x_dims is in NHWC Format iff. NHWC == true
|
||||
const auto x_dims = x_shape.AsShapeVector();
|
||||
|
||||
s_.x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
|
||||
s_.element_size = X->DataType()->Size();
|
||||
// set W
|
||||
bool w_in_nhwc;
|
||||
const Tensor* W;
|
||||
if (!W_) {
|
||||
W = context->Input<Tensor>(1);
|
||||
w_in_nhwc = W_already_nhwc;
|
||||
// Dims and memory layout are in NCHW format
|
||||
} else {
|
||||
W = W_.get();
|
||||
w_in_nhwc = true;
|
||||
// W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC
|
||||
}
|
||||
const TensorShape& w_shape = W->Shape();
|
||||
auto w_dims = w_shape.AsShapeVector();
|
||||
onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector();
|
||||
s_.w_data = reinterpret_cast<const CudaT*>(W->Data<T>());
|
||||
|
||||
// Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC.
|
||||
constexpr bool channels_last = NHWC;
|
||||
if constexpr (channels_last) {
|
||||
if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of dimensions of X and W should be 4 for channels_last format (NHWC)");
|
||||
}
|
||||
}
|
||||
|
||||
// set B
|
||||
// Always in NCHW format
|
||||
const Tensor* B = nullptr;
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
B = context->Input<Tensor>(2);
|
||||
s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
|
||||
} else {
|
||||
s_.b_data = nullptr;
|
||||
}
|
||||
|
||||
// set Z
|
||||
const Tensor* Z = nullptr;
|
||||
if (context->InputCount() >= 4) {
|
||||
const Tensor* Z = context->Input<Tensor>(3);
|
||||
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>()));
|
||||
Z = context->Input<Tensor>(3);
|
||||
s_.z_data = reinterpret_cast<const CudaT*>(Z->Data<T>());
|
||||
} else {
|
||||
s_.z_data = nullptr;
|
||||
|
@ -183,13 +306,12 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
|
||||
if (w_dims_changed) {
|
||||
s_.last_w_dims = gsl::make_span(w_dims);
|
||||
s_.cached_benchmark_results.clear();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last));
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc));
|
||||
|
||||
TensorShapeVector kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last));
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc));
|
||||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
|
@ -211,59 +333,46 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t M = W->Shape()[0];
|
||||
if (channels_last) {
|
||||
|
||||
if constexpr (channels_last) {
|
||||
y_dims.push_back(N);
|
||||
} else {
|
||||
y_dims.insert(y_dims.begin(), {N, M});
|
||||
}
|
||||
|
||||
bool post_slicing_required = false;
|
||||
TensorShapeVector slice_starts;
|
||||
slice_starts.reserve(kernel_rank);
|
||||
|
||||
TensorShapeVector slice_ends;
|
||||
slice_ends.reserve(kernel_rank);
|
||||
|
||||
TensorShapeVector slice_axes;
|
||||
slice_axes.reserve(kernel_rank);
|
||||
|
||||
constexpr size_t spatial_dim_start = channels_last ? 1 : 2;
|
||||
const size_t spatial_dim_end = spatial_dim_start + kernel_rank;
|
||||
TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end);
|
||||
|
||||
TensorShapeVector y_dims_with_adjusted_pads(y_dims);
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape,
|
||||
strides, dilations, pads, y_dims, y_dims_with_adjusted_pads,
|
||||
post_slicing_required, slice_starts, slice_ends, slice_axes,
|
||||
channels_last));
|
||||
if (channels_last) {
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(spatial_shape, kernel_shape,
|
||||
strides, dilations, pads, y_dims));
|
||||
if constexpr (channels_last) {
|
||||
y_dims.push_back(M);
|
||||
y_dims_with_adjusted_pads.push_back(M);
|
||||
}
|
||||
|
||||
ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
|
||||
s_.y_dims = gsl::make_span(y_dims);
|
||||
s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
|
||||
s_.post_slicing_required = post_slicing_required;
|
||||
s_.slice_starts = slice_starts;
|
||||
s_.slice_ends = slice_ends;
|
||||
s_.slice_axes = slice_axes;
|
||||
|
||||
s_.Y = context->Output(0, TensorShape(s_.y_dims));
|
||||
if (post_slicing_required) {
|
||||
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
|
||||
s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream());
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
|
||||
} else {
|
||||
// No post slicing needed. Fill the output tensor's buffer directly.
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
const CUDAExecutionProvider* cuda_ep =
|
||||
static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
|
||||
TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()};
|
||||
TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
|
||||
TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()};
|
||||
TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()};
|
||||
|
||||
if constexpr (channels_last) {
|
||||
x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1));
|
||||
y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1));
|
||||
x_dims_cudnn.erase(x_dims_cudnn.end() - 1);
|
||||
y_dims_cudnn.erase(y_dims_cudnn.end() - 1);
|
||||
|
||||
if (w_in_nhwc) {
|
||||
w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1));
|
||||
w_dims_cudnn.erase(w_dims_cudnn.end() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (kernel_rank < 2) {
|
||||
// TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D]
|
||||
// especially for EXHAUSTIVE algo search which may result in a better algo selection.
|
||||
|
@ -276,7 +385,7 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
|
||||
x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1);
|
||||
y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1);
|
||||
w_dims.insert(w_dims.begin() + 2, 1);
|
||||
w_dims_cudnn.insert(w_dims.begin() + 2, 1);
|
||||
pads.insert(pads.begin() + kernel_rank, 0);
|
||||
pads.insert(pads.begin(), 0);
|
||||
kernel_shape.insert(kernel_shape.begin(), 1);
|
||||
|
@ -285,7 +394,7 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
} else {
|
||||
x_dims_cudnn.push_back(1);
|
||||
y_dims_cudnn.push_back(1);
|
||||
w_dims.push_back(1);
|
||||
w_dims_cudnn.push_back(1);
|
||||
pads.insert(pads.begin() + kernel_rank, 0);
|
||||
pads.insert(pads.end(), 0);
|
||||
kernel_shape.push_back(1);
|
||||
|
@ -294,188 +403,105 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
}
|
||||
}
|
||||
|
||||
if (w_dims_changed) {
|
||||
if (!channels_last) {
|
||||
ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC,
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
static_cast<int>(w_dims[0]),
|
||||
static_cast<int>(w_dims[3]),
|
||||
static_cast<int>(w_dims[1]),
|
||||
static_cast<int>(w_dims[2])));
|
||||
}
|
||||
}
|
||||
|
||||
// We must delay returning early until here so that the weight dims have been cached properly
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (channels_last) {
|
||||
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC,
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
static_cast<int>(x_dims_cudnn[0]),
|
||||
static_cast<int>(x_dims_cudnn[3]),
|
||||
static_cast<int>(x_dims_cudnn[1]),
|
||||
static_cast<int>(x_dims_cudnn[2])));
|
||||
auto handle = GetCudnnHandle(context);
|
||||
|
||||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC,
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
static_cast<int>(y_dims_cudnn[0]),
|
||||
static_cast<int>(y_dims_cudnn[3]),
|
||||
static_cast<int>(y_dims_cudnn[1]),
|
||||
static_cast<int>(y_dims_cudnn[2])));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
|
||||
#if !defined(__CUDACC__)
|
||||
cudnn_frontend::HeurMode_t heur_mode;
|
||||
switch (cudnn_conv_algo) {
|
||||
case 0:
|
||||
heur_mode = cudnn_frontend::HeurMode_t::B;
|
||||
break;
|
||||
case 1:
|
||||
heur_mode = cudnn_frontend::HeurMode_t::A;
|
||||
break;
|
||||
case 2:
|
||||
heur_mode = cudnn_frontend::HeurMode_t::FALLBACK;
|
||||
LOGS_DEFAULT(WARNING) << "OP " << CudaKernel::Node().OpType() << "(" << CudaKernel::Node().Name()
|
||||
<< ") running in Fallback mode. May be extremely slow.";
|
||||
break;
|
||||
default:
|
||||
heur_mode = cudnn_frontend::HeurMode_t::A;
|
||||
break;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group),
|
||||
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
|
||||
UseTF32()));
|
||||
const auto use_tf32 = cuda_ep->UseTF32();
|
||||
// fuse if this op is part of a FusedConv or if the EP is set to fuse ops
|
||||
const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_;
|
||||
const auto fuse_act = is_fused_node_;
|
||||
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
const auto& b_shape = B->Shape();
|
||||
ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
|
||||
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
|
||||
b_dims[1] = b_shape[0];
|
||||
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
|
||||
// s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
|
||||
} else if (bias_expected) {
|
||||
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
|
||||
b_dims[1] = w_dims[0];
|
||||
auto malloc_size = b_dims[1] * sizeof(CudaT);
|
||||
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
|
||||
if (s_.b_zero) {
|
||||
CUDA_CALL_THROW(cudaFree(s_.b_zero));
|
||||
s_.b_zero = nullptr;
|
||||
}
|
||||
CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
|
||||
CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
|
||||
}
|
||||
|
||||
if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
|
||||
// set math type to tensor core before algorithm search
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
if (!UseTF32()) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
|
||||
}
|
||||
}
|
||||
|
||||
cudnnConvolutionFwdAlgoPerf_t perf;
|
||||
int algo_count = 1;
|
||||
int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
|
||||
ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo);
|
||||
switch (cudnn_conv_algo) {
|
||||
case 0: {
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
|
||||
GetCudnnHandle(context),
|
||||
s_.x_tensor,
|
||||
s_.x_data,
|
||||
s_.w_desc,
|
||||
s_.w_data,
|
||||
s_.conv_desc,
|
||||
s_.y_tensor,
|
||||
s_.y_data,
|
||||
1, // requestedAlgoCount
|
||||
&algo_count, // returnedAlgoCount
|
||||
&perf,
|
||||
algo_search_workspace.get(),
|
||||
max_ws_size));
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
|
||||
GetCudnnHandle(context),
|
||||
s_.x_tensor,
|
||||
s_.w_desc,
|
||||
s_.conv_desc,
|
||||
s_.y_tensor,
|
||||
1, // requestedAlgoCount
|
||||
&algo_count, // returnedAlgoCount
|
||||
&perf));
|
||||
break;
|
||||
|
||||
default:
|
||||
perf.algo = kDefaultConvAlgo;
|
||||
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
|
||||
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
perf.mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else if (std::is_same<T, float>::value && !UseTF32()) {
|
||||
perf.mathType = CUDNN_FMA_MATH;
|
||||
} else {
|
||||
perf.mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
}
|
||||
s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
|
||||
}
|
||||
const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType));
|
||||
s_.algo = perf.algo;
|
||||
s_.workspace_bytes = perf.memory;
|
||||
ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, Z, y_dims_cudnn, handle, heur_mode,
|
||||
std::vector<int64_t>(pads.begin(),
|
||||
pads.end()),
|
||||
std::vector<int64_t>(strides.begin(),
|
||||
strides.end()),
|
||||
std::vector<int64_t>(dilations.begin(),
|
||||
dilations.end()),
|
||||
bias_expected, fuse_bias, fuse_act, w_in_nhwc, use_tf32));
|
||||
#endif
|
||||
} else {
|
||||
// set Y
|
||||
s_.Y = context->Output(0, s_.y_dims);
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (s_.post_slicing_required) {
|
||||
s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream());
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
|
||||
} else {
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, bool NHWC>
|
||||
Status Conv<T, NHWC>::ComputeInternal(OpKernelContext* context) const {
|
||||
template <typename T, bool Layout>
|
||||
Status Conv<T, Layout>::ComputeInternal(OpKernelContext* context) const {
|
||||
std::lock_guard<OrtMutex> lock(s_.mutex);
|
||||
ORT_RETURN_IF_ERROR(UpdateState(context));
|
||||
if (s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
const auto alpha = Consts<CudaT>::One;
|
||||
const auto beta = Consts<CudaT>::Zero;
|
||||
IAllocatorUniquePtr<void> workspace = GetWorkSpace(context->GetComputeStream());
|
||||
const auto alpha = onnxruntime::cuda::Consts<CudaT>::One;
|
||||
auto cudnn_handle = GetCudnnHandle(context);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle,
|
||||
&alpha,
|
||||
s_.x_tensor,
|
||||
s_.x_data,
|
||||
s_.w_desc,
|
||||
s_.w_data,
|
||||
s_.conv_desc,
|
||||
s_.algo,
|
||||
workspace.get(),
|
||||
s_.workspace_bytes,
|
||||
&beta,
|
||||
s_.y_tensor,
|
||||
s_.y_data));
|
||||
if (nullptr != s_.b_data) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data,
|
||||
#if !defined(__CUDACC__)
|
||||
s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast<void*>(s_.x_data));
|
||||
s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast<void*>(s_.w_data));
|
||||
s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data);
|
||||
if (s_.bias_fused && s_.b_data != nullptr) {
|
||||
s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast<void*>(s_.b_data));
|
||||
}
|
||||
if (s_.bias_fused && s_.z_data != nullptr) {
|
||||
s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast<void*>(s_.z_data));
|
||||
if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) {
|
||||
// memset Z if it's required for a successful fusion
|
||||
CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes()));
|
||||
}
|
||||
}
|
||||
auto ws = GetWorkSpace(context->GetComputeStream());
|
||||
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle,
|
||||
s_.variant_pack,
|
||||
ws.get()));
|
||||
|
||||
if (!s_.bias_fused && s_.z_data != nullptr) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data,
|
||||
&alpha, s_.y_tensor, s_.y_data));
|
||||
}
|
||||
// To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions
|
||||
// This may have led to extra results that are unnecessary and hence we slice that off here
|
||||
if (s_.post_slicing_required) {
|
||||
ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads),
|
||||
s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts,
|
||||
s_.slice_ends, s_.slice_axes, s_.element_size));
|
||||
if (!s_.bias_fused && s_.b_data != nullptr) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data,
|
||||
&alpha, s_.y_tensor, s_.y_data));
|
||||
|
||||
/* For the standalone bias addition graph.
|
||||
s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data);
|
||||
s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data);
|
||||
s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast<void*>(s_.b_data));
|
||||
CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle,
|
||||
s_.variant_pack_bias,
|
||||
GetWorkSpace(context->GetComputeStream()).get()));*/
|
||||
}
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -536,6 +562,7 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// template instantiation for NhwcConv
|
||||
|
|
|
@ -6,6 +6,12 @@
|
|||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#if !defined(__CUDACC__)
|
||||
#include <cudnn_frontend.h>
|
||||
#endif
|
||||
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
@ -150,6 +156,24 @@ struct CudnnConvState {
|
|||
CudnnTensor z_tensor;
|
||||
const void* z_data = nullptr;
|
||||
CudnnConvolutionDescriptor conv_desc;
|
||||
bool bias_fused = true;
|
||||
bool act_fused = true;
|
||||
|
||||
#if !defined(__CUDACC__)
|
||||
std::unique_ptr<cudnn_frontend::graph::Graph> cudnn_fe_graph;
|
||||
std::unique_ptr<cudnn_frontend::graph::Graph> cudnn_fe_bias_graph;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_X;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_W;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_conv_Y;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_Z;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_B;
|
||||
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> cudnn_fe_Y;
|
||||
|
||||
std::optional<cudnn_frontend::graph::Pointwise_attributes> cudnn_fe_act_attr = std::nullopt;
|
||||
|
||||
std::unordered_map<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>, void*> variant_pack;
|
||||
std::unordered_map<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>, void*> variant_pack_bias;
|
||||
#endif
|
||||
|
||||
struct PerfResultParams {
|
||||
decltype(AlgoPerfType().algo) algo;
|
||||
|
@ -183,7 +207,7 @@ enum : size_t {
|
|||
|
||||
// ONNX Conv operator uses NCHW format for input, weights and output.
|
||||
// NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels.
|
||||
template <typename T, bool NHWC>
|
||||
template <typename T, bool Layout>
|
||||
class Conv : public CudaKernel {
|
||||
public:
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
|
@ -205,12 +229,32 @@ class Conv : public CudaKernel {
|
|||
}
|
||||
|
||||
Status UpdateState(OpKernelContext* context, bool bias_expected = false) const;
|
||||
|
||||
#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9
|
||||
Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims,
|
||||
const onnxruntime::TensorShapeVector& w_dims,
|
||||
const Tensor* B,
|
||||
const Tensor* Z,
|
||||
const TensorShapeVector& y_dims,
|
||||
cudnnContext* handle,
|
||||
const cudnn_frontend::HeurMode_t heur_mode,
|
||||
const std::vector<int64_t>& pads,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
const bool bias_expected,
|
||||
const bool fuse_bias,
|
||||
const bool fuse_act,
|
||||
const bool w_in_nhwc,
|
||||
const bool use_tf32) const;
|
||||
#endif
|
||||
|
||||
ConvAttributes conv_attrs_;
|
||||
mutable CudnnConvState<cudnnConvolutionFwdAlgoPerf_t> s_;
|
||||
constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
static const cudnnConvolutionFwdAlgo_t kAllAlgos[];
|
||||
std::unique_ptr<Tensor> W_;
|
||||
bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain
|
||||
bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain
|
||||
bool is_fused_node_ = false; // ensures the node is fused although the session option is not set
|
||||
bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain
|
||||
};
|
||||
|
||||
Status SliceOutUnwantedOutputSection(cudaStream_t stream,
|
||||
|
|
|
@ -0,0 +1,484 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) 2023 NVIDIA Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
#include "core/common/span_utils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/tensor/transpose.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
|
||||
#include "core/providers/cuda/tensor/slice.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
static const cudnnConvolutionFwdAlgo_t kAllAlgos[] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
};
|
||||
|
||||
static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) {
|
||||
return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz);
|
||||
}
|
||||
|
||||
size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s,
|
||||
const cudnnConvolutionFwdAlgo_t* algo, int n_algo) {
|
||||
// TODO: get maximum available size from memory arena
|
||||
size_t free, total;
|
||||
CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
|
||||
// Assuming 10% of fragmentation
|
||||
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
cudnnStatus_t err;
|
||||
size_t sz;
|
||||
err = GetWorkspaceSize(handle, s, algo[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue;
|
||||
max_ws_size = sz;
|
||||
}
|
||||
return max_ws_size;
|
||||
}
|
||||
|
||||
Status SliceOutUnwantedOutputSection(cudaStream_t stream,
|
||||
const void* input_data, gsl::span<const int64_t> input_dims,
|
||||
void* output_data,
|
||||
const gsl::span<const int64_t>& output_dims,
|
||||
const gsl::span<const int64_t>& starts,
|
||||
const gsl::span<const int64_t>& ends,
|
||||
const gsl::span<const int64_t>& axes,
|
||||
size_t element_size) {
|
||||
SliceOp::PrepareForComputeMetadata compute_metadata(input_dims);
|
||||
|
||||
ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata));
|
||||
|
||||
// As a sanity check, ensure that the slice operator's output shape matches with the expected output shape
|
||||
ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims));
|
||||
|
||||
return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size);
|
||||
}
|
||||
|
||||
template <typename T, bool NHWC>
|
||||
Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected) const {
|
||||
// set X
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& x_shape = X->Shape();
|
||||
const auto x_dims = x_shape.AsShapeVector();
|
||||
s_.x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
|
||||
s_.element_size = X->DataType()->Size();
|
||||
bool w_in_nhwc;
|
||||
const Tensor* W;
|
||||
if (!W_) {
|
||||
W = context->Input<Tensor>(1);
|
||||
w_in_nhwc = W_already_nhwc;
|
||||
// Dims and memory layout are in NCHW format
|
||||
} else {
|
||||
W = W_.get();
|
||||
w_in_nhwc = true;
|
||||
// W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC
|
||||
}
|
||||
const TensorShape& w_shape = W->Shape();
|
||||
auto w_dims = w_shape.AsShapeVector();
|
||||
s_.w_data = reinterpret_cast<const CudaT*>(W->Data<T>());
|
||||
|
||||
// Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC.
|
||||
constexpr bool channels_last = NHWC;
|
||||
if constexpr (channels_last) {
|
||||
if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) {
|
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Number of dimensions of X and W should be 4 for channels_last format (NHWC)");
    }
  }

  // set B
  if (context->InputCount() >= 3) {
    const Tensor* B = context->Input<Tensor>(2);
    s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
  } else {
    s_.b_data = nullptr;
  }
  // set Z
  if (context->InputCount() >= 4) {
    const Tensor* Z = context->Input<Tensor>(3);
    ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>()));
    s_.z_data = reinterpret_cast<const CudaT*>(Z->Data<T>());
  } else {
    s_.z_data = nullptr;
  }
  bool input_dims_changed = (s_.last_x_dims != x_dims);
  bool w_dims_changed = (s_.last_w_dims != w_dims);
  if (input_dims_changed || w_dims_changed) {
    if (input_dims_changed)
      s_.last_x_dims = gsl::make_span(x_dims);

    if (w_dims_changed) {
      s_.last_w_dims = gsl::make_span(w_dims);
      s_.cached_benchmark_results.clear();
    }

    ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc));

    TensorShapeVector kernel_shape;
    ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc));

    const size_t kernel_rank = kernel_shape.size();

    ConvPadVector pads(conv_attrs_.pads);
    if (pads.empty()) {
      pads.resize(kernel_rank * 2, 0);
    }
    TensorShapeVector dilations(conv_attrs_.dilations);
    if (dilations.empty()) {
      dilations.resize(kernel_rank, 1);
    }
    TensorShapeVector strides(conv_attrs_.strides);
    if (strides.empty()) {
      strides.resize(kernel_rank, 1);
    }

    TensorShapeVector y_dims;
    y_dims.reserve(2 + kernel_rank);  // add 2 to account for 'N' and 'C'

    const int64_t N = X->Shape()[0];
    const int64_t M = W->Shape()[0];
    if (channels_last) {
      y_dims.push_back(N);
    } else {
      y_dims.insert(y_dims.begin(), {N, M});
    }

    bool post_slicing_required = false;
    TensorShapeVector slice_starts;
    slice_starts.reserve(kernel_rank);

    TensorShapeVector slice_ends;
    slice_ends.reserve(kernel_rank);

    TensorShapeVector slice_axes;
    slice_axes.reserve(kernel_rank);

    constexpr size_t spatial_dim_start = channels_last ? 1 : 2;
    const size_t spatial_dim_end = spatial_dim_start + kernel_rank;
    TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end);

    TensorShapeVector y_dims_with_adjusted_pads(y_dims);
    ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape,
                                                                     strides, dilations, pads, y_dims, y_dims_with_adjusted_pads,
                                                                     post_slicing_required, slice_starts, slice_ends, slice_axes,
                                                                     channels_last));
    if (channels_last) {
      y_dims.push_back(M);
      y_dims_with_adjusted_pads.push_back(M);
    }

    ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
    s_.y_dims = gsl::make_span(y_dims);
    s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
    s_.post_slicing_required = post_slicing_required;
    s_.slice_starts = slice_starts;
    s_.slice_ends = slice_ends;
    s_.slice_axes = slice_axes;

    s_.Y = context->Output(0, TensorShape(s_.y_dims));
    if (post_slicing_required) {
      // Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
      s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream());
      s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
    } else {
      // No post slicing needed. Fill the output tensor's buffer directly.
      s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
    }

    const CUDAExecutionProvider* cuda_ep =
        static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());

    TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()};
    TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
    if (kernel_rank < 2) {
      // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D]
      // especially for EXHAUSTIVE algo search which may result in a better algo selection.
      // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to
      // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape
      // [N,C,D] to [N,C,1,D], especially on A100, and especially for ConvGrad.
      // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems
      // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT.
      // See PR #7348 and #7702 for more context.
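      // Added illustration with hypothetical shapes (not taken from this PR), following the
      // [N, C, D] layout used in the comment above: for X = [2, 3, 5] and kernel_shape = {3},
      // the NC1D path below reshapes to [2, 3, 1, 5] with kernel {1, 3}, while the default path
      // reshapes to [2, 3, 5, 1] with kernel {3, 1}. Either way cuDNN sees a 4-D (2-D spatial)
      // convolution, which is what the insert/push_back calls on the dims, pads, strides and
      // dilations below accomplish.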
      if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
        x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1);
        y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1);
        w_dims.insert(w_dims.begin() + 2, 1);
        pads.insert(pads.begin() + kernel_rank, 0);
        pads.insert(pads.begin(), 0);
        kernel_shape.insert(kernel_shape.begin(), 1);
        strides.insert(strides.begin(), 1);
        dilations.insert(dilations.begin(), 1);
      } else {
        x_dims_cudnn.push_back(1);
        y_dims_cudnn.push_back(1);
        w_dims.push_back(1);
        pads.insert(pads.begin() + kernel_rank, 0);
        pads.insert(pads.end(), 0);
        kernel_shape.push_back(1);
        strides.push_back(1);
        dilations.push_back(1);
      }
    }

    if (w_dims_changed) {
      if (!channels_last) {
        ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
      } else if (w_in_nhwc) {
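        // Note added for clarity: the filter descriptor takes logical K, C, H, W counts. In this
        // branch the weight tensor is physically laid out as [K, H, W, C], so the channel count is
        // w_dims[3] and the spatial extents are w_dims[1] and w_dims[2], hence the argument order below.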
        ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC,
                                          CudnnTensor::GetDataType<CudaT>(),
                                          static_cast<int>(w_dims[0]),
                                          static_cast<int>(w_dims[3]),
                                          static_cast<int>(w_dims[1]),
                                          static_cast<int>(w_dims[2])));
      } else {
        ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC,
                                          CudnnTensor::GetDataType<CudaT>(),
                                          static_cast<int>(w_dims[0]),
                                          static_cast<int>(w_dims[1]),
                                          static_cast<int>(w_dims[2]),
                                          static_cast<int>(w_dims[3])));
      }
    }

    // We must delay returning early until here so that the weight dims have been cached properly
    if (s_.Y->Shape().Size() == 0) {
      return Status::OK();
    }

    if (channels_last) {
      ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC,
                                          CudnnTensor::GetDataType<CudaT>(),
                                          static_cast<int>(x_dims_cudnn[0]),
                                          static_cast<int>(x_dims_cudnn[3]),
                                          static_cast<int>(x_dims_cudnn[1]),
                                          static_cast<int>(x_dims_cudnn[2])));

      ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC,
                                          CudnnTensor::GetDataType<CudaT>(),
                                          static_cast<int>(y_dims_cudnn[0]),
                                          static_cast<int>(y_dims_cudnn[3]),
                                          static_cast<int>(y_dims_cudnn[1]),
                                          static_cast<int>(y_dims_cudnn[2])));
    } else {
      ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
      ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
    }

    ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                         gsl::narrow_cast<int>(conv_attrs_.group),
                                         CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
                                         UseTF32()));

    if (context->InputCount() >= 3) {
      const Tensor* B = context->Input<Tensor>(2);
      const auto& b_shape = B->Shape();
      ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
      TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
      b_dims[1] = b_shape[0];
      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
      // s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
    } else if (bias_expected) {
      TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
      b_dims[1] = w_dims[0];
      auto malloc_size = b_dims[1] * sizeof(CudaT);
      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
      if (s_.b_zero) {
        CUDA_CALL_THROW(cudaFree(s_.b_zero));
        s_.b_zero = nullptr;
      }
      CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
      CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
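      // Note added for clarity: this branch handles the case where the kernel expects a bias input
      // but the node provides none. The code above allocates s_.b_zero with one element per output
      // channel (b_dims[1] == w_dims[0]) and zero-fills it asynchronously on this context's compute stream.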
    }

    if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
      // set math type to tensor core before algorithm search
      if constexpr (std::is_same<T, MLFloat16>::value) {
        CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
      } else if constexpr (std::is_same<T, float>::value) {
        if (!UseTF32()) {
          CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
        }
      }

      cudnnConvolutionFwdAlgoPerf_t perf;
      int algo_count = 1;
      int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
      ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo);
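      // Note added for clarity: the three accepted values select the search strategy handled by the
      // switch below: 0 runs an exhaustive search via cudnnFindConvolutionForwardAlgorithmEx,
      // 1 asks for a heuristic pick via cudnnGetConvolutionForwardAlgorithm_v7, and 2 skips the
      // search entirely and uses kDefaultConvAlgo.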
      switch (cudnn_conv_algo) {
        case 0: {
          static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
          size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos)
                                                                      : AlgoSearchWorkspaceSize;
          // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
          // Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
          IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
          CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
              GetCudnnHandle(context),
              s_.x_tensor,
              s_.x_data,
              s_.w_desc,
              s_.w_data,
              s_.conv_desc,
              s_.y_tensor,
              s_.y_data,
              1,            // requestedAlgoCount
              &algo_count,  // returnedAlgoCount
              &perf,
              algo_search_workspace.get(),
              max_ws_size));
          break;
        }
        case 1:
          CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
              GetCudnnHandle(context),
              s_.x_tensor,
              s_.w_desc,
              s_.conv_desc,
              s_.y_tensor,
              1,            // requestedAlgoCount
              &algo_count,  // returnedAlgoCount
              &perf));
          break;

        default:
          perf.algo = kDefaultConvAlgo;
          CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));

          if constexpr (std::is_same<T, MLFloat16>::value) {
            perf.mathType = CUDNN_TENSOR_OP_MATH;
          } else if (std::is_same<T, float>::value && !UseTF32()) {
            perf.mathType = CUDNN_FMA_MATH;
          } else {
            perf.mathType = CUDNN_DEFAULT_MATH;
          }
      }
      s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
    }
    const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn);
    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType));
    s_.algo = perf.algo;
    s_.workspace_bytes = perf.memory;
  } else {
    // set Y
    s_.Y = context->Output(0, s_.y_dims);
    if (s_.Y->Shape().Size() == 0) {
      return Status::OK();
    }
    if (s_.post_slicing_required) {
      s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream());
      s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
    } else {
      s_.y_data = reinterpret_cast<CudaT*>(s_.Y->MutableData<T>());
    }
  }
  return Status::OK();
}

template <typename T, bool NHWC>
Status Conv<T, NHWC>::ComputeInternal(OpKernelContext* context) const {
  std::lock_guard<OrtMutex> lock(s_.mutex);
  ORT_RETURN_IF_ERROR(UpdateState(context));
  if (s_.Y->Shape().Size() == 0) {
    return Status::OK();
  }
  const auto alpha = Consts<CudaT>::One;
  const auto beta = Consts<CudaT>::Zero;
  IAllocatorUniquePtr<void> workspace = GetWorkSpace(context->GetComputeStream());
  auto cudnn_handle = GetCudnnHandle(context);
  CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle,
                                                &alpha,
                                                s_.x_tensor,
                                                s_.x_data,
                                                s_.w_desc,
                                                s_.w_data,
                                                s_.conv_desc,
                                                s_.algo,
                                                workspace.get(),
                                                s_.workspace_bytes,
                                                &beta,
                                                s_.y_tensor,
                                                s_.y_data));
  if (nullptr != s_.b_data) {
    CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data,
                                         &alpha, s_.y_tensor, s_.y_data));
  }
  // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions.
  // This may have led to extra results that are unnecessary and hence we slice that off here.
  if (s_.post_slicing_required) {
    ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads),
                                                      s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts,
                                                      s_.slice_ends, s_.slice_axes, s_.element_size));
  }
  return Status::OK();
}

CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) {
}

CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() {
  if (desc_ != nullptr) {
    cudnnDestroyConvolutionDescriptor(desc_);
    desc_ = nullptr;
  }
}

Status CudnnConvolutionDescriptor::Set(
    size_t rank,
    const gsl::span<const int64_t>& pads,
    const gsl::span<const int64_t>& strides,
    const gsl::span<const int64_t>& dilations,
    int groups,
    cudnnConvolutionMode_t mode,
    cudnnDataType_t data_type,
    bool use_tf32) {
  if (!desc_)
    CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));

  InlinedVector<int, kTensorShapeSmallBufferElementsSize> pad_dims(rank);
  InlinedVector<int, kTensorShapeSmallBufferElementsSize> stride_dims(rank);
  InlinedVector<int, kTensorShapeSmallBufferElementsSize> dilation_dims(rank);
  for (size_t i = 0; i < rank; i++) {
    pad_dims[i] = gsl::narrow_cast<int>(pads[i]);
    stride_dims[i] = gsl::narrow_cast<int>(strides[i]);
    dilation_dims[i] = gsl::narrow_cast<int>(dilations[i]);
  }

  // This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
  // Setting math_type to CUDNN_DATA_FLOAT for half input
  cudnnDataType_t math_type = data_type;
  if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT;
  CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor(
      desc_,
      gsl::narrow_cast<int>(rank),
      pad_dims.data(),
      stride_dims.data(),
      dilation_dims.data(),
      mode,
      math_type));

  CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups));

  // Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
  // See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp
  CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
  if (data_type == CUDNN_DATA_HALF) {
    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
  } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
  }

  return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime
@@ -4,16 +4,16 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_pch.h"

namespace onnxruntime {

// -----------------------------------------------------------------------
// Error handling
// -----------------------------------------------------------------------

template <typename ERRTYPE, bool THRW>
template <typename ERRTYPE, bool THRW, typename SUCCTYPE = ERRTYPE>
std::conditional_t<THRW, void, Status> CudaCall(
    ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line);
    ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg,
    const char* file, const int line);

#define CUDA_CALL(expr) (CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
#define CUBLAS_CALL(expr) (CudaCall<cublasStatus_t, false>((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))
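// Note added for clarity (not part of the original header): the new SUCCTYPE template parameter
// lets the success constant have a different type from the returned error object, and it defaults
// to ERRTYPE so existing macros such as CUDA_CALL and CUBLAS_CALL are unaffected. The cuDNN
// frontend wrapper below relies on this, instantiating
// CudaCall<cudnn_frontend::error_t, false, cudnn_frontend::error_code_t> because cudnn_frontend
// calls return an error_t object while the success value is the enum error_code_t::OK.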
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_pch.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#if !defined(__CUDACC__)
#include <cudnn_frontend.h>
#endif
namespace onnxruntime {

// -----------------------------------------------------------------------
// Error handling
// -----------------------------------------------------------------------

#define CUDNN_FE_CALL(expr) (CudaCall<cudnn_frontend::error_t, false, \
                                      cudnn_frontend::error_code_t>((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \
                                      cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__))
#define CUDNN_FE_CALL_THROW(expr) (CudaCall<cudnn_frontend::error_t, true, \
                                            cudnn_frontend::error_code_t>((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \
                                            cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__))
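// Usage note added for clarity: CUDNN_FE_CALL expands to CudaCall<..., /*THRW=*/false, ...> and so
// evaluates to a Status that the caller is expected to check, while CUDNN_FE_CALL_THROW uses
// THRW == true and throws when the wrapped cudnn_frontend call does not report
// cudnn_frontend::error_code_t::OK.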
} // namespace onnxruntime

@@ -15,9 +15,6 @@

#include "core/providers/cuda/cuda_common.h"

using namespace onnxruntime;
using namespace onnxruntime::cuda;

// Generalize library calls to be used in template functions
inline cublasStatus_t
cublasGemmHelper(cublasHandle_t handle,

@@ -84,7 +81,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
                                       half* C, int ldc,
                                       const cudaDeviceProp& prop,
                                       bool /*use_tf32*/) {
  const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance();
  const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance();
  onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode());
  if (half_options->IsCompute16F()) {
    return cublasGemmEx(handle,

@@ -127,7 +124,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
                                       half* C, int ldc,
                                       const cudaDeviceProp& prop,
                                       bool /*use_tf32*/) {
  const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance();
  const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance();
  onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode());
  if (half_options->IsCompute16F()) {
    // The alpha and beta shall have same precision as compute type.

@@ -162,8 +159,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle,
#if defined(USE_CUDA)
inline cublasStatus_t cublasGemmHelper(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m,
    int n, int k, const BFloat16* alpha, const BFloat16* A, int lda,
    const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc,
    int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda,
    const onnxruntime::BFloat16* B, int ldb, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc,
    const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

@@ -174,8 +171,9 @@ inline cublasStatus_t cublasGemmHelper(
}
#else
inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int,
                                       const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*,
                                       BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) {
                                       const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int,
                                       const onnxruntime::BFloat16*, int, const onnxruntime::BFloat16*,
                                       onnxruntime::BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
#endif

@@ -250,7 +248,7 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
                                              int batch_count,
                                              const cudaDeviceProp& prop,
                                              bool /*use_tf32*/) {
  const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance();
  const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance();
  onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode());
  if (half_options->IsCompute16F()) {
    return cublasGemmBatchedEx(handle,

@@ -286,9 +284,9 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
#if defined(USE_CUDA)
inline cublasStatus_t cublasGemmBatchedHelper(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[],
    int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta,
    BFloat16* Carray[], int ldc, int batch_count,
    int m, int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* Aarray[],
    int lda, const onnxruntime::BFloat16* Barray[], int ldb, const onnxruntime::BFloat16* beta,
    onnxruntime::BFloat16* Carray[], int ldc, int batch_count,
    const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

@@ -300,8 +298,9 @@ inline cublasStatus_t cublasGemmBatchedHelper(
}
#else
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int,
                                              const BFloat16*, const BFloat16*[], int, const BFloat16*[], int,
                                              const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&,
                                              const onnxruntime::BFloat16*, const onnxruntime::BFloat16*[], int,
                                              const onnxruntime::BFloat16*[], int, const onnxruntime::BFloat16*,
                                              onnxruntime::BFloat16*[], int, int, const cudaDeviceProp&,
                                              bool /*use_tf32*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}

@@ -314,12 +313,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
                                                     int m, int n, int k,
                                                     const float* alpha,
                                                     const float* A, int lda,
                                                     long long int strideA,
                                                     int64_t strideA,
                                                     const float* B, int ldb,
                                                     long long int strideB,
                                                     int64_t strideB,
                                                     const float* beta,
                                                     float* C, int ldc,
                                                     long long int strideC,
                                                     int64_t strideC,
                                                     int batch_count,
                                                     const cudaDeviceProp& prop,
                                                     bool use_tf32) {

@@ -349,12 +348,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
                                                     int m, int n, int k,
                                                     const double* alpha,
                                                     const double* A, int lda,
                                                     long long int strideA,
                                                     int64_t strideA,
                                                     const double* B, int ldb,
                                                     long long int strideB,
                                                     int64_t strideB,
                                                     const double* beta,
                                                     double* C, int ldc,
                                                     long long int strideC,
                                                     int64_t strideC,
                                                     int batch_count,
                                                     const cudaDeviceProp& /*prop*/,
                                                     bool /*use_tf32*/) {

@@ -376,16 +375,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
                                                     int m, int n, int k,
                                                     const __half* alpha,
                                                     const __half* A, int lda,
                                                     long long int strideA,
                                                     int64_t strideA,
                                                     const __half* B, int ldb,
                                                     long long int strideB,
                                                     int64_t strideB,
                                                     const __half* beta,
                                                     __half* C, int ldc,
                                                     long long int strideC,
                                                     int64_t strideC,
                                                     int batch_count,
                                                     const cudaDeviceProp& prop,
                                                     bool /*use_tf32*/) {
  const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance();
  const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance();
  onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode());
  if (half_options->IsCompute16F()) {
    return cublasGemmStridedBatchedEx(handle,

@@ -425,16 +424,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
                                                     int m, int n, int k,
                                                     const float* alpha,
                                                     const __half* A, int lda,
                                                     long long int strideA,
                                                     int64_t strideA,
                                                     const __half* B, int ldb,
                                                     long long int strideB,
                                                     int64_t strideB,
                                                     const float* beta,
                                                     __half* C, int ldc,
                                                     long long int strideC,
                                                     int64_t strideC,
                                                     int batch_count,
                                                     const cudaDeviceProp& prop,
                                                     bool /*use_tf32*/) {
  const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance();
  const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance();
  onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode());
  if (half_options->IsCompute16F()) {
    // The alpha and beta shall have same precision as compute type.

@@ -472,10 +471,10 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(
inline cublasStatus_t cublasGemmStridedBatchedHelper(
    cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k,
    const BFloat16* alpha, const BFloat16* A, int lda,
    long long int strideA, const BFloat16* B, int ldb,
    long long int strideB, const BFloat16* beta, BFloat16* C, int ldc,
    long long int strideC, int batch_count,
    const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda,
    int64_t strideA, const onnxruntime::BFloat16* B, int ldb,
    int64_t strideB, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc,
    int64_t strideC, int batch_count,
    const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) {
  float h_a = alpha->ToFloat();
  float h_b = beta->ToFloat();

@@ -488,9 +487,9 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(
#else
inline cublasStatus_t cublasGemmStridedBatchedHelper(
    cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int,
    int, const BFloat16*, const BFloat16*, int, long long int,
    const BFloat16*, int, long long int, const BFloat16*, BFloat16*,
    int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) {
    int, const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, int64_t,
    const onnxruntime::BFloat16*, int, int64_t, const onnxruntime::BFloat16*, onnxruntime::BFloat16*,
    int, int64_t, int, const cudaDeviceProp&, bool /*use_tf32*/) {
  return CUBLAS_STATUS_NOT_SUPPORTED;
}
#endif

@@ -531,4 +530,5 @@ cublasStatus_t cublasCopyHelper(
    cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy);

cublasStatus_t cublasCopyHelper(
    cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy);
    cudaStream_t stream, cublasHandle_t handle, int n, const onnxruntime::BFloat16* x,
    int incx, onnxruntime::BFloat16* y, int incy);
@@ -1782,6 +1782,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const
  cuda_options_converted.cudnn_conv_use_max_workspace = 1;
  cuda_options_converted.enable_cuda_graph = 0;
  cuda_options_converted.prefer_nhwc = 0;
  cuda_options_converted.fuse_conv_bias = 0;
  cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0;
  cuda_options_converted.enable_skip_layer_norm_strict_mode = 0;
  cuda_options_converted.use_ep_level_unified_stream = 0;
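For context, a minimal sketch (not part of this diff) of how an application could opt in to the new behavior through the V2 CUDA provider options of the C API. The option keys "prefer_nhwc" and "fuse_conv_bias" mirror the struct fields initialized above; the surrounding calls are ordinary provider-options boilerplate, and error handling is omitted for brevity.

#include <onnxruntime_c_api.h>

// Sketch: request NHWC convolution kernels plus conv+bias fusion on the CUDA EP.
void AppendCudaWithConvBiasFusion(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  api->CreateCUDAProviderOptions(&cuda_options);
  const char* keys[] = {"prefer_nhwc", "fuse_conv_bias"};
  const char* values[] = {"1", "1"};
  api->UpdateCUDAProviderOptions(cuda_options, keys, values, 2);
  api->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options);
  // The session options keep their own copy, so the provider options can be released here.
  api->ReleaseCUDAProviderOptions(cuda_options);
}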
@@ -30,7 +30,8 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes,
                    bool use_float16,
                    bool weight_is_initializer = false) {
  int min_cuda_architecture = use_float16 ? 530 : 0;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
  // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && weight_is_initializer;
  bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
  bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());

@@ -438,6 +438,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
  if (enable_cuda) {
#ifdef USE_CUDA
    OrtCUDAProviderOptionsV2 cuda_options;
    cuda_options.device_id = device_id;
    cuda_options.do_copy_in_default_stream = true;
    cuda_options.use_tf32 = false;
    // TODO: Support arena configuration for users of test runner

@@ -1993,67 +1993,22 @@ TEST_F(GraphTransformationTests, NotWhereFusion) {
  ASSERT_TRUE(op_to_count["Not"] == 1);  // can't remove Not if it is a graph output / has a consumer that's not Where
}

#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS)
// Conv->Add->Relu will be transformed to FusedConv
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
  }
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Add"] == 1);
  ASSERT_TRUE(op_to_count["Relu"] == 1);
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
  op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Add"] == 0);   // Add removed from graph
  ASSERT_TRUE(op_to_count["Relu"] == 0);  // Relu removed from graph
}

// Currently the ConvAddRelu fusion is only backed by a float kernel for
// the CUDA EP.

// When we see the corresponding pattern for the fp16 data type, the fusion
// should not be triggered as there is no kernel to back the fused pattern.

// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of
// data type support in other compatible EPs is still yet to be ascertained.

// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove
// this test.
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
  }
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);
  ASSERT_EQ(op_to_count["Relu"], 1);
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
      std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);   // Add not removed from graph (fusion not triggered)
  ASSERT_EQ(op_to_count["Relu"], 1);  // Relu not removed from graph (fusion not triggered)
}

#if !defined(DISABLE_CONTRIB_OPS)
// Conv->Add->Relu will be left intact since there is Identity depend on Add
TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();
#if defined(USE_JSEP)
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
    node.SetExecutionProviderType(kJsExecutionProvider);
  }
#else
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }
#endif
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Add"] == 1);
  ASSERT_TRUE(op_to_count["Relu"] == 1);

@@ -2073,9 +2028,15 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) {
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();
#if defined(USE_JSEP)
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
    node.SetExecutionProviderType(kJsExecutionProvider);
  }
#else
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }
#endif
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Add"] == 1);
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};

@@ -2165,18 +2126,14 @@ TEST_F(GraphTransformationTests, FuseConvActivation) {
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();
#ifdef USE_CUDA
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
  }
#elif defined(USE_ROCM)
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
  }
#elif defined(USE_JSEP)
#if defined(USE_JSEP)
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kJsExecutionProvider);
  }
#else
  for (auto& node : p_model->MainGraph().Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }
#endif
  std::map<std::string, int> op_to_count_before_fusion = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count_before_fusion[model.second] >= 1);

@@ -2187,14 +2144,7 @@ TEST_F(GraphTransformationTests, FuseConvActivation) {
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  std::map<std::string, int> op_to_count_after_fusion = CountOpsInGraph(graph);
#if defined(USE_CUDA) || defined(USE_ROCM)
  std::set<std::string> cuda_rocm_supported = {"Relu"};
  if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) {
    ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]);
  } else {
    ASSERT_EQ(op_to_count_after_fusion[model.second], 0);
  }
#elif defined(USE_JSEP)
#if defined(USE_JSEP)
  std::set<std::string> js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"};
  if (js_supported.find(model.second) == js_supported.end()) {
    ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]);

@@ -25,6 +25,7 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
                const std::initializer_list<float>& expected_output,
                const vector<int64_t>& expected_output_shape,
                bool weight_is_initializer = false,
                optional<float> epsilon = optional<float>(),
                OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
                const std::string& err_str = "",
                int opset = 7) {

@@ -56,11 +57,13 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,

  test.AddOutput<float>("Y", expected_output_shape, expected_output);

  if (epsilon.has_value()) {
    test.SetOutputTolerance(*epsilon);
  }

  std::unordered_set<std::string> excluded_providers(attributes.excluded_providers);
  // Disable TensorRT because weight as input is not supported
  excluded_providers.insert(kTensorrtExecutionProvider);
  // Disable CUDA NHWC execution provider as it is currently flaky
  excluded_providers.insert(kCudaNHWCExecutionProvider);

  // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs.
  excluded_providers.insert(kQnnExecutionProvider);

@@ -189,10 +192,15 @@ TEST(ConvTest, Conv1D_Bias) {
  vector<int64_t> Y_shape = {2, 1, 4};
  auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f,
                        0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f};
  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);

  // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only,
  // as TF32 has a 10 bit mantissa.
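  // Added note: TF32 keeps 10 explicit mantissa bits versus 23 for IEEE FP32 (unit roundoff of
  // roughly 4.9e-4 versus 6e-8), which is why a looser tolerance than usual is passed below.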
  float epsilon = 1.1e-5f;

  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon);

  // CoreML EP requires weight to be an initializer
  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true);
  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon);
}

// Conv47

@@ -240,7 +248,7 @@ TEST(ConvTest, Conv1D_Invalid_Input_Shape) {
  vector<int64_t> X_shape = {1, 1, 1};
  vector<int64_t> dummy_shape = {1, 1, 2};
  auto dummy_vals = {0.0f, 0.0f};
  TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false,
  TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional<float>(),
             OpTester::ExpectResult::kExpectFailure,
             "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. "
             "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2",

@@ -263,7 +271,7 @@ TEST(ConvTest, Conv2D_Invalid_Input_Shape) {
  vector<int64_t> dummy_shape = {2, 2, 1, 2};
  auto dummy_vals = {-0.0f, 0.0f, -0.0f, -0.0f,
                     -0.0f, 0.0f, -0.0f, -0.0f};
  TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false,
  TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional<float>(),
             OpTester::ExpectResult::kExpectFailure,
             "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. "
             "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0",

@@ -620,7 +628,12 @@ TEST(ConvTest, Conv3D_Bias) {
                        -0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f,
                        -0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f,
                        -0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f};
  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);

  // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only,
  // as TF32 has a 10 bit mantissa.
  float epsilon = 2.1e-4f;

  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon);
}

TEST(ConvTest, Conv2D_group) {

@@ -902,7 +915,8 @@ TEST(ConvTest, ConvDimWithZero) {
  // not handled by ACL
  attrs.excluded_providers.insert(kAclExecutionProvider);

  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, OpTester::ExpectResult::kExpectSuccess, "", 10);
  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, optional<float>(),
             OpTester::ExpectResult::kExpectSuccess, "", 10);
}

TEST(ConvTest, Conv1D_asymmetric_padding) {
setup.py
@@ -196,6 +196,7 @@ try:
to_preload_cann = []

cuda_dependencies = [
    "libcuda.so.1",
    "libcublas.so.11",
    "libcublas.so.12",
    "libcublasLt.so.11",

@@ -11,7 +11,7 @@ steps:
    packageType: upack
    feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
    definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
    version: 1.0.173
    version: 1.0.175
    downloadPath: $(Build.BinariesDirectory)/deps

# The private ADO project

@@ -22,7 +22,7 @@ steps:
    packageType: upack
    feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
    definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
    version: 1.0.173
    version: 1.0.175
    downloadPath: $(Build.BinariesDirectory)/deps

# You can add more ADO accounts here.