sync to flash attention kernel 2.5.9 and add document of how to write custom op (#757)
* sync to flash attention kernel 2.5.9 * support users to overload GetMayInplace and ReleaseMayInplace * Undo the change for pybind11 dependency
This commit is contained in:
Родитель
b436d09459
Коммит
95d65e4ec0
|
@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no
|
|||
|
||||
add_compile_definitions(USE_CUDA)
|
||||
|
||||
set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use
|
||||
set(OCOS_USE_FLASH_ATTENTION OFF)
|
||||
if (OCOS_USE_FLASH_ATTENTION)
|
||||
message(STATUS "Enable flash attention")
|
||||
add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# How to write custom ops
|
||||
|
||||
Custom Ops are based on ONNXRuntime-extensions API, especially **OrtLiteCustomOp** and **Tensor** class. C++ template metaprogramming is heavily used under the hood to provide big flexibility to the Custom Op authors on the parameter's count, type and order.
|
||||
|
||||
## Basic scenario
|
||||
|
||||
You have 2 ways to write a custom op: by writing a function, or by writing a structure.
|
||||
|
||||
### Custom op in the form of function
|
||||
|
||||
If your kernel is simple, you can use this option by just providing a function to compute the customized kernel. That function can have arbitrary number of inputs and outputs. For the inputs that are mandatory, their type would be like:
|
||||
|
||||
```C++
|
||||
const Ort::Custom::Tensor<T>&
|
||||
// or
|
||||
const Ort::Custom::Tensor<T>*
|
||||
```
|
||||
|
||||
For the inputs that are optional, their type would be like:
|
||||
|
||||
```C++
|
||||
std::optional<const Ort::Custom::Tensor<T>*>
|
||||
```
|
||||
|
||||
The function can also accept the pointer of **CUDAKernelContext**, where you can retrieve CUDA stream and other CUDA resources, if it requires to be run in CUDA GPU.
|
||||
|
||||
The function will return the type **OrtStatusPtr**
|
||||
|
||||
Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types.
|
||||
|
||||
### Custom op in the form of structure
|
||||
|
||||
If the kernel is complicated and there are extra properties of the custom op, you can use this option by providing a C++ structure where you can put these properties as the structure's member variables. Besides that, you also need to provide the following member functions:
|
||||
|
||||
```C++
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initialize the properties of the custom op
|
||||
|
||||
OrtStatusPtr Compute(...) const // This function computes the customized kernel.
|
||||
```
|
||||
|
||||
The specification of the parameters of the Compute function is the same as the first way (custom op in the form of function)
|
||||
|
||||
## Advanced scenario
|
||||
|
||||
In some cases you need more control on the parameters, in this case you have to use the structure form, which you need to provide the implementations of the following member functions such as:
|
||||
|
||||
```C++
|
||||
// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs,
|
||||
// you can provide your own implementation to specify the ith input is in CPU or GPU.
|
||||
static OrtMemType GetInputMemoryType(size_t input_index)
|
||||
|
||||
// You can specify input i shares the same memory with output j if possible, by allocating
|
||||
// two array with same length for the pointer input_index and output_index seperately, and
|
||||
// then let (*input_index)[k] = i and (*output_index)[k] = j.
|
||||
// The return value is the length of the allocated array.
|
||||
static size_t GetMayInplace(int** input_index, int** output_index)
|
||||
|
||||
// Release the allocated array from the GetMayInplace() function.
|
||||
static void ReleaseMayInplace(int* input_index, int* output_index)
|
||||
```
|
|
@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
return INPUT_OUTPUT_OPTIONAL;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if ORT_API_VERSION >= 18
|
||||
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
|
||||
return 0;
|
||||
};
|
||||
OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
|
||||
#endif
|
||||
}
|
||||
|
||||
const std::string op_name_;
|
||||
|
|
|
@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
|
|||
template <typename T>
|
||||
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct CustomOp_defined_getMayInplace : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct CustomOp_defined_releaseMayInplace : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>> : std::true_type {};
|
||||
|
||||
template <typename CustomOpKernel>
|
||||
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
|
||||
using ComputeFunction = decltype(&CustomOpKernel::Compute);
|
||||
|
@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
|
|||
};
|
||||
}
|
||||
|
||||
#if ORT_API_VERSION >= 18
|
||||
if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
|
||||
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
|
||||
return CustomOpKernel::GetMayInplace(input_index, output_index);
|
||||
};
|
||||
}
|
||||
if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
|
||||
OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
|
||||
CustomOpKernel::ReleaseMayInplace(input_index, output_index);
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
|
||||
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
|
||||
if (api == nullptr) {
|
||||
|
|
|
@ -10,7 +10,6 @@ This enables more flexibility and control over model execution, thus expanding t
|
|||
|
||||
__author__ = "Microsoft"
|
||||
|
||||
|
||||
from ._version import __version__
|
||||
from ._ocos import get_library_path
|
||||
from ._ocos import Opdef, PyCustomOpDef
|
||||
|
@ -66,6 +65,10 @@ if _lib_only:
|
|||
gen_processing_models = _unimplemented
|
||||
OrtPyFunction = _unimplemented
|
||||
ort_inference = _unimplemented
|
||||
PyOrtFunction = _unimplemented
|
||||
optimize_model = _unimplemented
|
||||
make_onnx_model = _unimplemented
|
||||
ONNXRuntimeError = _unimplemented
|
||||
|
||||
else:
|
||||
__all__ += _offline_api
|
||||
|
|
|
@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params {
|
|||
// The indices to index into the KV cache.
|
||||
int* __restrict__ cache_batch_idx = nullptr;
|
||||
|
||||
// Paged KV cache
|
||||
int * __restrict__ block_table;
|
||||
index_t block_table_batch_stride;
|
||||
int page_block_size;
|
||||
|
||||
float rp_dropout;
|
||||
|
||||
// Local window size
|
||||
int window_size_left = -1;
|
||||
int window_size_right = -1;
|
||||
|
@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params {
|
|||
|
||||
int num_splits = 0; // For split-KV version
|
||||
|
||||
void * __restrict__ alibi_slopes_ptr;
|
||||
index_t alibi_slopes_batch_stride;
|
||||
|
||||
const cudaDeviceProp* dprops = nullptr;
|
||||
};
|
||||
|
||||
|
|
|
@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params,
|
|||
bool is_bf16,
|
||||
bool kv_bsnh = true,
|
||||
int window_size_left = -1,
|
||||
int window_size_right = -1) {
|
||||
int window_size_right = -1,
|
||||
bool paged_KV = false,
|
||||
int page_block_size = -1) {
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q;
|
||||
params.k_ptr = k;
|
||||
|
@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params,
|
|||
|
||||
if (cu_seqlens_q_d == nullptr) {
|
||||
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
|
||||
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
|
||||
params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
|
||||
params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
|
||||
params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
|
||||
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
|
||||
} else {
|
||||
params.q_batch_stride = 0;
|
||||
|
@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params,
|
|||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
|
||||
params.rp_dropout = 1.f;
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
params.alibi_slopes_batch_stride = 0;
|
||||
|
||||
// In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates
|
||||
// local and causal, meaning when we have local window size
|
||||
params.is_causal = is_causal;
|
||||
|
@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
|
|||
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
||||
cudaStream_t stream,
|
||||
void* q, // batch_size x seqlen_q x num_heads x head_size
|
||||
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
|
||||
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
|
||||
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
|
||||
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
|
||||
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* out, // batch_size x seqlen_q x num_heads x head_size
|
||||
|
@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
|
||||
int local_window_size,
|
||||
bool is_rotary_interleaved,
|
||||
bool is_packed_qkv) {
|
||||
bool is_packed_qkv,
|
||||
int32_t* block_table, // batch_size x max_num_blocks_per_seq
|
||||
int32_t max_num_blocks_per_seq,
|
||||
int32_t page_block_size) {
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
|
@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
is_bf16,
|
||||
past_bsnh,
|
||||
local_window_size,
|
||||
is_causal ? 0 : -1);
|
||||
is_causal ? 0 : -1,
|
||||
block_table != nullptr,
|
||||
page_block_size);
|
||||
params.dprops = &dprops;
|
||||
|
||||
if (k_new != nullptr && v_new != nullptr) {
|
||||
|
@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
params.oaccum_ptr = nullptr;
|
||||
}
|
||||
|
||||
params.block_table = block_table;
|
||||
params.block_table_batch_stride = max_num_blocks_per_seq;
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
// Only split kernel supports appending to KV cache
|
||||
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
|
||||
|
||||
|
|
|
@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
|
|||
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
||||
cudaStream_t stream,
|
||||
void* q, // batch_size x seqlen_q x num_heads x head_size
|
||||
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
|
||||
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
|
||||
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
|
||||
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
|
||||
void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* out, // batch_size x seqlen_q x num_heads x head_size
|
||||
|
@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
|
||||
int local_window_size = -1,
|
||||
bool is_rotary_interleaved = false,
|
||||
bool is_packed_qkv = false);
|
||||
bool is_packed_qkv = false,
|
||||
int32_t* block_table = nullptr, // batch_size x max_num_blocks_per_seq
|
||||
int32_t max_num_blocks_per_seq = -1,
|
||||
int32_t page_block_size = 1);
|
||||
|
||||
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
|
||||
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -9,20 +9,20 @@
|
|||
|
||||
namespace flash {
|
||||
|
||||
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
flash::compute_attn<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
#else
|
||||
(void)params;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
|
||||
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
|
||||
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
#else
|
||||
(void)params;
|
||||
#endif
|
||||
|
@ -38,7 +38,7 @@ __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
|
|||
#endif
|
||||
}
|
||||
|
||||
template <typename Kernel_traits, bool Is_causal>
|
||||
template <typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
|
||||
|
@ -53,23 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
// ORT_ENFORCE(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel < Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
// ORT_ENFORCE(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -90,16 +92,18 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
|
||||
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
|
||||
auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
|
||||
auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -143,7 +147,7 @@ template <typename T>
|
|||
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -154,7 +158,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
|
||||
});
|
||||
|
@ -168,12 +172,12 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
if constexpr (!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
|
||||
|
@ -192,12 +196,12 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
|
||||
if (is_sm8x) {
|
||||
if constexpr (!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
|
||||
|
@ -220,12 +224,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// and 128 x 64 with 8 warps is the fastest for non-causal.
|
||||
if (is_sm8x) {
|
||||
if constexpr (!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
|
@ -241,7 +245,7 @@ template <typename T>
|
|||
void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) {
|
||||
constexpr int Headdim = 192;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
|
@ -257,9 +261,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
|
@ -280,9 +284,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) {
|
|||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
|
||||
}
|
||||
// 64 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
|
||||
|
|
|
@ -54,10 +54,10 @@ __device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor
|
|||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum) {
|
||||
SumOp<float> sum_op;
|
||||
reduce_(tensor, sum, sum_op);
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
|
@ -212,4 +212,168 @@ inline __device__ void apply_mask_causal_w_idx(
|
|||
}
|
||||
}
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
__forceinline__ __device__ Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
|
||||
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
if (Is_first) {
|
||||
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? row_max(mi)
|
||||
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
||||
// We do that reduce at the end when we need to normalize the softmax.
|
||||
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
||||
}
|
||||
};
|
||||
|
||||
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT lse = make_fragment_like(row_sum);
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
return lse;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool Is_causal, bool Is_local, bool Has_alibi>
|
||||
struct Mask {
|
||||
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
const int window_size_left, window_size_right;
|
||||
const float alibi_slope;
|
||||
|
||||
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
|
||||
const int window_size_left, const int window_size_right,
|
||||
const float alibi_slope=0.f)
|
||||
: max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q)
|
||||
, window_size_left(window_size_left)
|
||||
, window_size_right(window_size_right)
|
||||
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
|
||||
};
|
||||
|
||||
// Causal_mask: whether this particular iteration needs causal masking
|
||||
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
|
||||
static_assert(Layout::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
|
||||
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
|
||||
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
|
||||
if constexpr (Need_masking) {
|
||||
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
|
||||
// Do we need both row and column indices, or just column incides?
|
||||
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Col_idx_only) {
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// No causal, no local
|
||||
if constexpr (Has_alibi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
if constexpr (!Is_even_MN) {
|
||||
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if constexpr (Has_alibi) {
|
||||
if constexpr (Is_causal) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
} else {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
|
||||
}
|
||||
}
|
||||
if constexpr (Causal_mask) {
|
||||
if (col_idx >= col_idx_limit_right) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (Is_local) {
|
||||
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
|
||||
// Causal and Local already handles MN masking
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
|
|
|
@ -198,6 +198,28 @@ inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
template <typename Layout>
|
||||
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
|
@ -212,6 +234,25 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
template<typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
template <typename MMA_traits, typename Layout>
|
||||
|
|
Загрузка…
Ссылка в новой задаче