41 строка
1.6 KiB
CMake
41 строка
1.6 KiB
CMake
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# Enable the CUDA language before locating the toolkit. When the CUDA language
# is already enabled, FindCUDAToolkit uses the directory of the selected CUDA
# compiler as its first search location. The CUDA::* imported targets then
# come from the same toolkit as the nvcc that compiles the sources. Searching
# first can pick a different install on machines with several CUDA versions.
# REQUIRED: CUDA is enabled unconditionally here, so a missing toolkit should
# fail at configure time instead of surfacing later as unknown CUDA::* targets.
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)

# Initialize the CUDA_RUNTIME_LIBRARY property of targets created after this
# point so they use the shared CUDA runtime.
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)

# Default CUDA language standard for targets created after this point.
set(CMAKE_CUDA_STANDARD 17)

# cmake_dependent_option() is provided by the CMakeDependentOption module.
# Including it here keeps this file from relying on an include made elsewhere;
# including it again is harmless.
include(CMakeDependentOption)

# Flash attention: offered with default ON when "NOT WIN32" holds. Otherwise
# the option is not offered and OCOS_USE_FLASH_ATTENTION is set to OFF.
cmake_dependent_option(OCOS_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF)

# Memory efficient attention: plain user option, default ON.
option(OCOS_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

# Both attention kernels are turned off when the CUDA compiler is older than 11.6.
# NOTE: set() creates normal variables that shadow the cached options in this
# scope, so the cache (e.g. cmake-gui/ccmake) may still display ON.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
  message(STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6")
  set(OCOS_USE_FLASH_ATTENTION OFF)
  set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

# Global nvcc flags. They are appended to CMAKE_CUDA_FLAGS, so every CUDA
# compilation configured after this point receives them.

# Relaxed constexpr rules between host and device code (nvcc experimental flag).
string(APPEND CMAKE_CUDA_FLAGS " --expt-relaxed-constexpr")

# CUDA 11+: promote nvcc's default-stream-launch diagnostic to an error.
# This diagnostic fires for kernel launches that omit an explicit stream argument.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
  string(APPEND CMAKE_CUDA_FLAGS " --Werror default-stream-launch")
endif()

# Position-independent host code on non-Windows platforms.
# CUDA_NVCC_FLAGS is only consumed by the legacy FindCUDA module
# (cuda_add_library and similar). Targets compiled through
# enable_language(CUDA) ignore it. Forward -fPIC to the host compiler through
# CMAKE_CUDA_FLAGS so it actually takes effect. The legacy append is kept for
# any FindCUDA-based consumer.
if(NOT WIN32)
  list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
  string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler=-fPIC")
endif()

# Options passed to cudafe: suppress these front-end diagnostics.
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=bad_friend_decl\"")
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=unsigned_compare_with_zero\"")
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=expr_has_no_effect\"")

# USE_CUDA is added at directory scope (COMPILE_DEFINITIONS), not to a single target.
add_compile_definitions(USE_CUDA)

# Each enabled attention option is announced at configure time and exported
# as a preprocessor definition with the same name as the option.
if(OCOS_USE_FLASH_ATTENTION)
  message(STATUS "Enable flash attention")
  add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
endif()

if(OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
  message(STATUS "Enable memory efficient attention")
  add_compile_definitions(OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
endif()