36 строки
1.4 KiB
CMake
36 строки
1.4 KiB
CMake
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
|
|
|
|
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
|
|
|
# Triton kernel scripts that must be compiled ahead of time.
# Entries are paths relative to the repository root; this file only
# populates the list when the ROCm execution provider is enabled.
if(onnxruntime_USE_ROCM)
  set(triton_kernel_scripts
      "onnxruntime/core/providers/rocm/math/softmax_triton.py"
      "onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py")
endif()
|
|
|
|
# Registers the build rule that runs tools/ci_build/compile_triton.py over the
# scripts listed in `triton_kernel_scripts`, producing a `.a` file and a `.h`
# file under "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels".
#
# Arguments (each is the NAME of a variable to set in the caller's scope):
#   out_triton_kernel_obj_file   - receives the full path of the generated
#                                  triton_kernel_infos.a file.
#   out_triton_kernel_header_dir - receives the directory that contains the
#                                  generated triton_kernel_infos.h header.
#
# Reads from the enclosing scope:
#   REPO_ROOT             - repository root used to locate the compiler script
#                           and to resolve the kernel script paths.
#   triton_kernel_scripts - list of kernel scripts, relative to REPO_ROOT.
#
# Side effect: defines the custom target `onnxruntime_triton_kernel`, which
# depends on both generated files.
#
# Example:
#   compile_triton_kernel(triton_obj_file triton_header_dir)
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
  # compile triton kernel, generate .a and .h files
  set(triton_kernel_compiler "${REPO_ROOT}/tools/ci_build/compile_triton.py")
  set(out_dir "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels")
  set(out_obj_file "${out_dir}/triton_kernel_infos.a")
  set(header_file "${out_dir}/triton_kernel_infos.h")

  # Turn the repo-relative script paths into absolute ones. A function has its
  # own variable scope, so this only changes the local copy of the list.
  list(TRANSFORM triton_kernel_scripts PREPEND "${REPO_ROOT}/")

  add_custom_command(
    OUTPUT "${out_obj_file}" "${header_file}"
    # Make sure the output directory exists before the tool writes into it,
    # instead of relying on the script to create it.
    COMMAND "${CMAKE_COMMAND}" -E make_directory "${out_dir}"
    # The script list is intentionally left unquoted so that every script is
    # passed as a separate command-line argument.
    COMMAND Python3::Interpreter "${triton_kernel_compiler}"
            --header "${header_file}"
            --script_files ${triton_kernel_scripts}
            --obj_file "${out_obj_file}"
    DEPENDS ${triton_kernel_scripts} "${triton_kernel_compiler}"
    COMMENT "Triton compile generates: ${out_obj_file}"
    # VERBATIM keeps argument escaping identical across generators/platforms.
    VERBATIM
  )
  add_custom_target(onnxruntime_triton_kernel DEPENDS "${out_obj_file}" "${header_file}")

  # Hand the results back to the caller through the requested variable names.
  set(${out_triton_kernel_obj_file} "${out_obj_file}" PARENT_SCOPE)
  set(${out_triton_kernel_header_dir} "${out_dir}" PARENT_SCOPE)
endfunction()
|