Mirror of https://github.com/microsoft/DeepSpeed.git
DeepSpeed-FastGen (#4604)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Parent
737ef296cd
Commit
38b41dffa1
|
@ -0,0 +1,56 @@
|
|||
name: nv-a6000
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- 'docs/**'
|
||||
- 'blogs/**'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
unit-tests:
|
||||
runs-on: [self-hosted, nvidia, a6000]
|
||||
container:
|
||||
image: nvcr.io/nvidia/pytorch:23.03-py3
|
||||
ports:
|
||||
- 80
|
||||
options: --gpus all --shm-size "8G"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
ldd --version
|
||||
nvcc --version
|
||||
nvidia-smi
|
||||
python -c "import torch; print('torch:', torch.__version__, torch)"
|
||||
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
|
||||
- name: Install transformers
|
||||
run: |
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
git rev-parse --short HEAD
|
||||
python -m pip install .
|
||||
- name: Install deepspeed
|
||||
run: |
|
||||
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
|
||||
python -m pip install .[dev,1bit,autotuning]
|
||||
ds_report
|
||||
- name: Python environment
|
||||
run: |
|
||||
python -m pip list
|
||||
- name: Unit tests
|
||||
run: |
|
||||
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
|
||||
cd tests
|
||||
python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12"
|
||||
python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12"
|
|
@ -33,7 +33,7 @@ jobs:
|
|||
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
|
||||
- name: Compile DeepSpeed Ops
|
||||
run: |
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
|
||||
- name: DS Report
|
||||
run: |
|
||||
ds_report
|
||||
|
|
|
@ -49,6 +49,7 @@ repos:
|
|||
entry: ./scripts/check-license.py
|
||||
language: python
|
||||
files: \.(py|c|cpp|cu|cc|h|hpp|cuh|hip|tr)$
|
||||
exclude: ^(deepspeed/inference/v2/kernels/ragged_ops/blocked_flash|deepspeed/inference/v2/kernels/cutlass_ops/grouped_gemm)
|
||||
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.1.0
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
include *.txt README.md
|
||||
include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
|
||||
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
|
||||
recursive-include requirements *.txt
|
||||
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
|
||||
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
|
||||
|
|
|
@ -153,9 +153,26 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
|
|||
def total_memory(self, device_index=None):
|
||||
return torch.cuda.get_device_properties(device_index).total_memory
|
||||
|
||||
def _get_nvml_gpu_id(self, torch_gpu_id):
|
||||
"""
|
||||
credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020
|
||||
|
||||
Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.
|
||||
|
||||
If the latter isn't set, return the same id.
|
||||
"""
|
||||
# if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
|
||||
return ids[torch_gpu_id] # remap
|
||||
else:
|
||||
return torch_gpu_id
|
||||
|
||||
def available_memory(self, device_index=None):
|
||||
if pynvml:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
|
||||
if device_index is None:
|
||||
device_index = self.current_device()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.free
|
||||
else:
|
||||
|
|
|
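Illustrative sketch, not part of this diff: the remap performed by `_get_nvml_gpu_id` above matters because pynvml ignores `CUDA_VISIBLE_DEVICES`. The environment value below is hypothetical.

```python
import os

# With CUDA_VISIBLE_DEVICES="2,3", torch device 0 is physical GPU 2, so pynvml
# must be queried with index 2 rather than 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
ids = list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
assert ids[0] == 2 and ids[1] == 3  # torch id -> nvml id
```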
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
enum ActivationType {
|
||||
GELU = 0,
|
||||
RELU = 1,
|
||||
SILU = 2,
|
||||
GEGLU = 3,
|
||||
ReGLU = 4,
|
||||
SiGLU = 5,
|
||||
IDENTITY = 6,
|
||||
InvalidType = -1
|
||||
};
|
|
@ -11,6 +11,11 @@ used throughout the codebase.
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#define DS_HD_INLINE __host__ __device__ __forceinline__
|
||||
#define DS_D_INLINE __device__ __forceinline__
|
||||
|
|
|
@ -233,6 +233,60 @@ DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
|
|||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t element<ROpType::Add>(const int32_t lhs, const int32_t rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t element<ROpType::Max>(const int32_t lhs, const int32_t rhs)
|
||||
{
|
||||
return (lhs > rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t element<ROpType::Min>(const int32_t lhs, const int32_t rhs)
|
||||
{
|
||||
return (lhs < rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t element<ROpType::Add>(const uint32_t lhs, const uint32_t rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t element<ROpType::Max>(const uint32_t lhs, const uint32_t rhs)
|
||||
{
|
||||
return (lhs > rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t element<ROpType::Min>(const uint32_t lhs, const uint32_t rhs)
|
||||
{
|
||||
return (lhs < rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t element<ROpType::Add>(const int64_t lhs, const int64_t rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t element<ROpType::Max>(const int64_t lhs, const int64_t rhs)
|
||||
{
|
||||
return (lhs > rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t element<ROpType::Min>(const int64_t lhs, const int64_t rhs)
|
||||
{
|
||||
return (lhs < rhs) ? lhs : rhs;
|
||||
}
|
||||
|
||||
/*
|
||||
Reduction initialization primitives
|
||||
*/
|
||||
|
@ -310,6 +364,78 @@ DS_D_INLINE __half2 init<ROpType::Max>()
|
|||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t init<ROpType::Add>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t init<ROpType::Min>()
|
||||
{
|
||||
return 0x7FFFFFFF;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int32_t init<ROpType::Max>()
|
||||
{
|
||||
return 0x80000000;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t init<ROpType::Add>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t init<ROpType::Min>()
|
||||
{
|
||||
return 0xFFFFFFFF;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint32_t init<ROpType::Max>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t init<ROpType::Add>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t init<ROpType::Min>()
|
||||
{
|
||||
return 0x7FFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE int64_t init<ROpType::Max>()
|
||||
{
|
||||
return 0x8000000000000000;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint64_t init<ROpType::Add>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint64_t init<ROpType::Min>()
|
||||
{
|
||||
return 0xFFFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE uint64_t init<ROpType::Max>()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <ROpType Op, typename T>
|
||||
DS_D_INLINE void init(T* data)
|
||||
{
|
||||
|
@ -352,8 +478,8 @@ here (fold is C++17 only and I don't think helps and recursion feels like
|
|||
huge overkill that harms readability) that would be wonderful.
|
||||
*/
|
||||
|
||||
template <ROpType Op, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
||||
template <typename T, ROpType Op, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 1; i < reduce_width; i *= 2) {
|
||||
|
@ -361,8 +487,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
|||
}
|
||||
}
|
||||
|
||||
template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
||||
template <typename T, ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 1; i < reduce_width; i *= 2) {
|
||||
|
@ -371,8 +497,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
|||
}
|
||||
}
|
||||
|
||||
template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
||||
template <typename T, ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 1; i < reduce_width; i *= 2) {
|
||||
|
@ -382,8 +508,13 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
|||
}
|
||||
}
|
||||
|
||||
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
||||
template <typename T,
|
||||
ROpType Op1,
|
||||
ROpType Op2,
|
||||
ROpType Op3,
|
||||
ROpType Op4,
|
||||
int reduce_width = hw_warp_size>
|
||||
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 1; i < reduce_width; i *= 2) {
|
||||
|
@ -403,16 +534,15 @@ the number of warps in the block (which may exceed that
|
|||
if the block is partitioned or if we do a conservative bound at
|
||||
compile time).
|
||||
*/
|
||||
template <int total_warps, ROpType... Ops>
|
||||
template <typename T, int total_warps, ROpType... Ops>
|
||||
DS_D_INLINE void _block(cg::thread_block& tb,
|
||||
cg::thread_block_tile<hw_warp_size>& warp_arg,
|
||||
float* data)
|
||||
T* data)
|
||||
{
|
||||
constexpr int elems = sizeof...(Ops);
|
||||
// Separated for now in case this no longer is true
|
||||
constexpr int bytes = sizeof(float);
|
||||
constexpr int bytes = sizeof(T);
|
||||
// Unused when `partition_size == 1` or total_warps == 1
|
||||
__shared__ float reduce_buffer[max_warps * elems];
|
||||
__shared__ T reduce_buffer[max_warps * elems];
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
const int total_threads = blockDim.x * blockDim.y * blockDim.z;
|
||||
|
@ -422,7 +552,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
|
|||
#endif
|
||||
|
||||
// Always perform warp-scope reduction
|
||||
_warp<Ops...>(warp_arg, data);
|
||||
_warp<T, Ops...>(warp_arg, data);
|
||||
|
||||
// If max_warps == 1 let's skip the runtime check
|
||||
if (total_warps != 1) {
|
||||
|
@ -447,7 +577,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
|
|||
init<Ops...>(data);
|
||||
}
|
||||
|
||||
_warp<Ops..., total_warps>(warp_arg, data);
|
||||
_warp<T, Ops..., total_warps>(warp_arg, data);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems; i++) {
|
||||
|
@ -476,7 +606,7 @@ us to obfuscate the details of the partitioned implementation.
|
|||
template <ROpType Op, int warp_bound>
|
||||
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
|
||||
{
|
||||
_block<warp_bound, Op>(tb, warp, &val);
|
||||
_block<float, warp_bound, Op>(tb, warp, &val);
|
||||
}
|
||||
|
||||
template <ROpType Op1, ROpType Op2, int warp_bound>
|
||||
|
@ -486,7 +616,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val2)
|
||||
{
|
||||
float data[2] = {val1, val2};
|
||||
_block<warp_bound, Op1, Op2>(tb, warp, data);
|
||||
_block<float, warp_bound, Op1, Op2>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
}
|
||||
|
@ -499,7 +629,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val3)
|
||||
{
|
||||
float data[3] = {val1, val2, val3};
|
||||
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
|
||||
_block<float, warp_bound, Op1, Op2, Op3>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
val3 = data[2];
|
||||
|
@ -514,7 +644,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val4)
|
||||
{
|
||||
float data[4] = {val1, val2, val3, val4};
|
||||
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
_block<float, warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
val3 = data[2];
|
||||
|
@ -531,10 +661,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
float& val)
|
||||
{
|
||||
if (num_threads <= hw_warp_size) {
|
||||
_warp<Op, num_threads>(warp, &val);
|
||||
_warp<float, Op, num_threads>(warp, &val);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
_block<num_warps, Op>(tb, warp, &val);
|
||||
_block<float, num_warps, Op>(tb, warp, &val);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -547,10 +677,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
float data[2] = {val1, val2};
|
||||
|
||||
if (num_threads <= hw_warp_size) {
|
||||
_warp<Op1, Op2, num_threads>(warp, data);
|
||||
_warp<float, Op1, Op2, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
_block<num_warps, Op1, Op2>(tb, warp, data);
|
||||
_block<float, num_warps, Op1, Op2>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
@ -567,10 +697,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
float data[3] = {val1, val2, val3};
|
||||
|
||||
if (num_threads <= hw_warp_size) {
|
||||
_warp<Op1, Op2, Op3, num_threads>(warp, data);
|
||||
_warp<float, Op1, Op2, Op3, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
_block<num_warps, Op1, Op2, Op3>(tb, warp, data);
|
||||
_block<float, num_warps, Op1, Op2, Op3>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
@ -589,10 +719,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
float data[4] = {val1, val2, val3, val4};
|
||||
|
||||
if (num_threads <= hw_warp_size) {
|
||||
_warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
|
||||
_warp<float, Op1, Op2, Op3, Op4, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
_block<float, num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
@ -601,4 +731,48 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
val4 = data[3];
|
||||
}
|
||||
|
||||
/*
|
||||
Arg-reduce is a specialization of the above. We only support this with a single reduction
|
||||
parameter. This only works for max/min reductions.
|
||||
*/
|
||||
|
||||
__align__(8) struct IdxReduceResult {
|
||||
/*
|
||||
NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits
|
||||
and the val is the most significant. Changing the order of this declaration
|
||||
will break the code.
|
||||
*/
|
||||
int idx;
|
||||
float val;
|
||||
};
|
||||
|
||||
template <ROpType Op, int warpBound>
|
||||
DS_D_INLINE IdxReduceResult
|
||||
idx_reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float val, int idx)
|
||||
{
|
||||
IdxReduceResult res = {idx, val};
|
||||
|
||||
// Clear out the nan. This shouldn't be an issue for our initial applications
|
||||
if (isnan(val)) res.val = init<Op>();
|
||||
|
||||
// Can do float compares as integers. By packing the index into the lower bits
|
||||
// we can just do a single int64 rather than a branch, compare, and select.
|
||||
// One side benefit of this is that it is by nature a stable algorithm and
|
||||
// will always bias ties to the higher index.
|
||||
int64_t* res_as_int = reinterpret_cast<int64_t*>(&res);
|
||||
|
||||
// The way floating point compare works is normally to perform a sign comparison
|
||||
// and if they match, then do a comparison of the rest of the bits as unsigned
|
||||
// integers. Since we are bundling these, that means for negative values we need
|
||||
// to reverse the sort order, which we can do with an XOR.
|
||||
if (val < 0) { *res_as_int ^= 0x7fffffff00000000; }
|
||||
|
||||
_block<int64_t, warpBound, Op>(tb, warp, res_as_int);
|
||||
|
||||
// Sign bit is preserved, so we can check if we need to invert the mantissa back
|
||||
if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; }
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace reduce
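A host-side sketch (illustrative only, not part of this diff) of the bit-packing trick used by `idx_reduce` above. It assumes the same layout as `IdxReduceResult` — index in the low 32 bits, float bits in the high 32 — and the same XOR fix-up for negative values.

```python
import struct

def pack(val: float, idx: int) -> int:
    # Little-endian: idx occupies the low 32 bits, the float bits the high 32.
    bits = struct.unpack("<q", struct.pack("<if", idx, val))[0]
    # Negative floats compare backwards as integers; flip the non-sign value bits.
    if val < 0:
        bits ^= 0x7FFFFFFF00000000
    return bits

candidates = [(0.5, 0), (2.0, 1), (2.0, 3), (-1.0, 2)]
print(max(candidates, key=lambda p: pack(*p)))  # (2.0, 3): ties resolve to the higher index
```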
|
||||
|
|
|
@ -18,8 +18,8 @@ using __nv_bfloat162 = __half2;
|
|||
|
||||
inline __device__ float gelu(const float x)
|
||||
{
|
||||
const float sqrt_param = 0.79788456080286535587989211986876f;
|
||||
const float mul_param = 0.044715;
|
||||
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
|
||||
constexpr float mul_param = 0.044715;
|
||||
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
|
||||
}
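The constants in `gelu` above implement the usual tanh approximation, with `sqrt_param ≈ sqrt(2/pi)`. A quick host-side check against PyTorch's reference (illustrative, not part of this diff):

```python
import math
import torch

x = torch.linspace(-4.0, 4.0, steps=9)
approx = x * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
assert torch.allclose(approx, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
```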
|
||||
|
||||
|
|
|
@ -2,3 +2,6 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
|
||||
from .v2.engine_v2 import InferenceEngineV2
|
||||
from .v2 import build_hf_engine
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
|
||||
from .engine_v2 import InferenceEngineV2
|
||||
from .engine_factory import build_hf_engine
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from functools import reduce
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor:
|
||||
shape_size = reduce(lambda x, y: x * y, shape)
|
||||
if shape_size == 0:
|
||||
raise ValueError("Cannot create empty tensor with size 0")
|
||||
return tensor.flatten()[:shape_size].view(shape)
|
||||
|
||||
|
||||
def on_device(method) -> torch.Tensor:
|
||||
"""
|
||||
Wraps a method to ensure the returned tensor is on the current device.
|
||||
"""
|
||||
|
||||
def wrapped(self, *args, **kwargs):
|
||||
tensor = method(self, *args, **kwargs)
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return tensor.to(get_accelerator().current_device()).contiguous()
|
||||
return tensor
|
||||
|
||||
return wrapped
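A usage sketch for `empty_from` above (illustrative; `empty_from` is assumed to be in scope since the module path is not shown in this diff). The point is that the result is a zero-copy view into a pre-allocated scratch buffer rather than a fresh allocation.

```python
import torch

scratch = torch.empty(1 << 20, dtype=torch.float16)  # one reusable flat buffer
view = empty_from(scratch, (8, 4096))                # 8 x 4096 view, no new memory
assert view.data_ptr() == scratch.data_ptr()
assert view.shape == (8, 4096)
```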
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .base_engine import CheckpointEngineBase
|
||||
from .in_memory_engine import InMemoryModelEngine
|
||||
from .huggingface_engine import HuggingFaceCheckpointEngine
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
#from .huggingface_engine import HuggingFaceCheckpointEngine
|
||||
|
||||
MEGATRON = 'megatron'
|
||||
HUGGINGFACE = 'huggingface'
|
||||
|
||||
|
||||
class CheckpointEngineBase(ABC):
|
||||
"""
|
||||
Abstract interface for checkpoint engines to implement.
|
||||
|
||||
There is no ``__init__`` method here by design, since the creation of the checkpoint
|
||||
engine will happen outside the policy/engine code. The tradeoff being made here is
|
||||
that we will write different frontends for different checkpoint engines, but these
|
||||
frontends can be tailored to the specific checkpoint engine/model source needs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
This method should create a generator of tuples of the form (name, parameter) for
|
||||
all parameters in the model. The name should be the fully qualified name of the
|
||||
parameter, and the parameter should be a torch.Tensor.
|
||||
|
||||
The expected use of a checkpoint engine is the following:
|
||||
```python
|
||||
for name, parameter in checkpoint_engine.parameters():
|
||||
container_map.map_param(name, parameter)
|
||||
```
|
||||
For a concrete use example, see ``InferenceV2Policy``.
|
||||
"""
|
||||
...
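A minimal sketch of a concrete engine satisfying this interface (hypothetical class, not part of this diff), reading a single `.pt` state dict from disk:

```python
class SingleFileCheckpointEngine(CheckpointEngineBase):
    """Illustrative engine that streams parameters from one torch state dict."""

    def __init__(self, path: str) -> None:
        self._path = path

    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        state_dict = torch.load(self._path, map_location="cpu")
        for name, param in state_dict.items():
            yield name, param
```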
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from .base_engine import CheckpointEngineBase
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
from ..logging import inference_logger
|
||||
|
||||
|
||||
class HuggingFaceCheckpointEngine(CheckpointEngineBase):
|
||||
|
||||
def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
|
||||
super().__init__()
|
||||
from transformers import AutoConfig, GenerationConfig
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.auth_token = auth_token
|
||||
self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
|
||||
self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
|
||||
# Define this property here so we can use it in the model implementation
|
||||
if not hasattr(self.model_config, "max_seq_length"):
|
||||
self.model_config.max_seq_length = self.model_config.max_position_embeddings
|
||||
else:
|
||||
self.model_config.max_seq_length = self.generation_config.max_length
|
||||
|
||||
self._all_ckpt_paths = self._fetch_checkpoint_files()
|
||||
|
||||
def _fetch_checkpoint_files(self):
|
||||
"""
|
||||
Fetch the checkpoint files from the HuggingFace Hub.
|
||||
"""
|
||||
# TODO(jeff): for models like llama-2 the user will have to provide an auth `token`,
|
||||
# currently coming from the ckpt engine init but maybe a catch all kwargs for other
|
||||
# snapshot download parameters would be more flexible.
|
||||
|
||||
# NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
|
||||
# checkpoint files that may be present. Example of all files in the llama-2-7b
|
||||
# repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if os.path.isdir(self.model_name_or_path):
|
||||
self._local_checkpoint_dir = self.model_name_or_path
|
||||
else:
|
||||
self._local_checkpoint_dir = snapshot_download(self.model_name_or_path,
|
||||
allow_patterns=[
|
||||
"*.bin",
|
||||
"*.json",
|
||||
"*.pt",
|
||||
],
|
||||
revision=None,
|
||||
token=self.auth_token)
|
||||
|
||||
assert os.path.isdir(
|
||||
self._local_checkpoint_dir
|
||||
), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint."
|
||||
|
||||
model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json")
|
||||
|
||||
if not os.path.isfile(model_param_json):
|
||||
# We don't need any json as all such HF models will have pytorch_model.bin
|
||||
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')]
|
||||
else:
|
||||
param_map = json.load(open(model_param_json, "r"))
|
||||
|
||||
# weight_map -> { "lm_head.weight": "pytorch_model-00002-of-00002.bin", ... }
|
||||
weight_map = param_map["weight_map"]
|
||||
|
||||
# unique set of all checkpoint files
|
||||
all_checkpoint_files = set(weight_map.values())
|
||||
|
||||
# get absolute path of all unique checkpoint files
|
||||
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, f) for f in all_checkpoint_files]
|
||||
|
||||
return all_checkpoint_files
|
||||
|
||||
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
Generator of model parameters (satisfies the CheckpointEngineBase interface).
|
||||
"""
|
||||
for checkpoint in self._all_ckpt_paths:
|
||||
inference_logger().info(f"Loading checkpoint: {checkpoint}")
|
||||
checkpoint_sd = torch.load(checkpoint, map_location='cpu')
|
||||
param_keys = list(checkpoint_sd.keys())
|
||||
for param_name in param_keys:
|
||||
param = checkpoint_sd[param_name]
|
||||
yield param_name, param
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To test, add your auth_token here and run `python huggingface_engine.py`
|
||||
engine = HuggingFaceCheckpointEngine(model_name_or_path="meta-llama/Llama-2-7b-hf",
|
||||
auth_token="hf_xxxxxxxxxxxxxxxxx")
|
||||
for name, param in engine.parameters():
|
||||
print(name, param.shape)
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Iterable, Tuple
|
||||
import torch
|
||||
|
||||
from .base_engine import CheckpointEngineBase
|
||||
|
||||
|
||||
class InMemoryModelEngine(CheckpointEngineBase):
|
||||
"""
|
||||
This "checkpoint" engine uses the existing interface to enable loading parameters into an
|
||||
inference model from a model already instantiated in memory. In general, this is not the
|
||||
recommended way to use the inference engine, and should only be used when absolutely necessary.
|
||||
|
||||
The primary limitation of this approach is that the model must be fully instantiated in memory.
|
||||
In a tensor parallel scenario, this means that the model is replicated many times in host
|
||||
memory. Currently, it is also recommended to only use this approach for models held in host memory.
|
||||
|
||||
In order to free the memory held by this copy of the model, we delete the model in the first call
|
||||
to `parameters`, so it is not safe to make this call twice.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Create virtual checkpoint engine for the provided module.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to load parameters from.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
for name, parameter in self.model.named_parameters():
|
||||
yield name, parameter
|
||||
|
||||
del self.model
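Usage sketch (illustrative; the `torch.nn.Linear` stands in for a real model already resident in host memory):

```python
import torch

model = torch.nn.Linear(16, 16)            # stand-in for an instantiated model
engine = InMemoryModelEngine(model)
names = [name for name, _ in engine.parameters()]
assert "weight" in names and "bias" in names
# Note: the wrapped model is deleted after the first full pass over parameters().
```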
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from deepspeed.pydantic_v1 import Field
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
from .ragged import DSStateManagerConfig
|
||||
|
||||
|
||||
class DeepSpeedTPConfig(DeepSpeedConfigModel):
|
||||
""" Configure tensor parallelism settings """
|
||||
|
||||
tp_size: int = 1
|
||||
""" Number of devices to split the model across using tensor parallelism. """
|
||||
|
||||
|
||||
class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
|
||||
""" Sets parameters for DeepSpeed Inference Engine. """
|
||||
|
||||
tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
|
||||
"""
|
||||
Configuration for tensor parallelism used to split the model across several
|
||||
GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
|
||||
"""
|
||||
|
||||
state_manager: DSStateManagerConfig = Field({}, alias="manager")
|
||||
"""
|
||||
Configuration for managing persistent state
|
||||
"""
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .engine_v2 import InferenceEngineV2
|
||||
from .config_v2 import RaggedInferenceEngineConfig
|
||||
from .checkpoint import HuggingFaceCheckpointEngine
|
||||
from .logging import inference_logger
|
||||
|
||||
|
||||
def build_hf_engine(path: str,
|
||||
engine_config: RaggedInferenceEngineConfig,
|
||||
debug_level: int = logging.INFO,
|
||||
random_weights_config: Any = None,
|
||||
fill_random: bool = False) -> InferenceEngineV2:
|
||||
"""
|
||||
Build an InferenceV2 engine for HuggingFace models.
|
||||
"""
|
||||
# Set up logging
|
||||
inference_logger(level=debug_level)
|
||||
|
||||
# get HF checkpoint engine
|
||||
checkpoint_engine = HuggingFaceCheckpointEngine(path)
|
||||
|
||||
# get model config from HF AutoConfig
|
||||
model_config = checkpoint_engine.model_config
|
||||
|
||||
# get the policy
|
||||
# TODO: generalize this to other models
|
||||
if model_config.model_type == "opt":
|
||||
from .model_implementations.opt.policy import OPTPolicy
|
||||
policy = OPTPolicy(checkpoint_engine, model_config)
|
||||
elif model_config.model_type == "llama":
|
||||
from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
|
||||
policy = Llama2Policy(checkpoint_engine, model_config)
|
||||
elif model_config.model_type == "mistral":
|
||||
from .model_implementations.mistral.policy import MistralPolicy
|
||||
policy = MistralPolicy(checkpoint_engine, model_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type {model_config.model_type}")
|
||||
|
||||
return InferenceEngineV2(policy, engine_config)
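A usage sketch for the factory above (the model id is a placeholder; running it needs a GPU and either a launcher or single-process distributed defaults):

```python
import logging
from deepspeed.inference.v2 import RaggedInferenceEngineConfig, build_hf_engine

config = RaggedInferenceEngineConfig()          # defaults: tp_size=1, default state manager
engine = build_hf_engine("facebook/opt-125m",   # placeholder HF model id
                         config,
                         debug_level=logging.INFO)
```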
|
|
@ -0,0 +1,217 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import deepspeed.comm as dist
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.comm.comm import init_distributed
|
||||
|
||||
from .model_implementations import InferenceV2Policy
|
||||
from .logging import inference_logger
|
||||
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
|
||||
from .scheduling_utils import SchedulingError, SchedulingResult
|
||||
|
||||
from .config_v2 import RaggedInferenceEngineConfig
|
||||
|
||||
INFERENCE_MODEL_TIMER = "model-forward-inference"
|
||||
|
||||
|
||||
class InferenceEngineV2:
|
||||
|
||||
_config: RaggedInferenceEngineConfig
|
||||
"""
|
||||
Configuration of the inference engine.
|
||||
"""
|
||||
|
||||
#_model: DSInferenceModelBase
|
||||
"""
|
||||
Inference model supporting ragged inference.
|
||||
"""
|
||||
|
||||
_state_manager: DSStateManager
|
||||
"""
|
||||
Persistent state manager for sequences and KV-cache.
|
||||
"""
|
||||
|
||||
@property
|
||||
def free_blocks(self) -> int:
|
||||
"""
|
||||
Number of free KV blocks.
|
||||
"""
|
||||
return self._state_manager.free_blocks
|
||||
|
||||
def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
|
||||
"""
|
||||
Create the Inference V2 engine.
|
||||
|
||||
Arguments:
|
||||
policy (InferenceV2Policy): Policy for the model implementation. This policy object
|
||||
will be used to build the model and load the checkpoint associated with it.
|
||||
engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine.
|
||||
"""
|
||||
self._config = engine_config
|
||||
self._policy = policy
|
||||
self._base_mp_group = self._initialize_tp_group()
|
||||
|
||||
# Build model from policy
|
||||
inference_logger().info("Building model...")
|
||||
self._model = self._policy.build_model(self._config, self._base_mp_group)
|
||||
inference_logger().info("Model built.")
|
||||
|
||||
# Create state manager
|
||||
self._batch = RaggedBatchWrapper(self._config.state_manager)
|
||||
self._state_manager = DSStateManager(self._config.state_manager,
|
||||
self._model.kv_cache_config(),
|
||||
base_mp_group=self._base_mp_group)
|
||||
self._model.set_state_manager(self._state_manager)
|
||||
|
||||
def _initialize_tp_group(self):
|
||||
"""
|
||||
Implementation of our TP group initialization.
|
||||
"""
|
||||
init_distributed()
|
||||
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
||||
get_accelerator().set_device(local_rank)
|
||||
|
||||
if local_rank >= self._config.tensor_parallel.tp_size:
|
||||
raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.")
|
||||
|
||||
ranks = list(range(self._config.tensor_parallel.tp_size))
|
||||
return dist.new_group(ranks=ranks)
|
||||
|
||||
def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Put a ragged batch onto the inference engine. This will perform one forward and return
|
||||
a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens
|
||||
are not calculated.
|
||||
|
||||
Arguments:
|
||||
batch_uids: Iterable of uids for the batch on the host
|
||||
batch_tokens: Iterable of token tensors for the batch on the host
|
||||
"""
|
||||
|
||||
token_lens = [len(tokens) for tokens in batch_tokens]
|
||||
schedule_check = self.can_schedule(batch_uids, token_lens)
|
||||
if schedule_check != SchedulingResult.Success:
|
||||
raise SchedulingError(schedule_check)
|
||||
|
||||
self._batch.clear()
|
||||
for uid, tokens in zip(batch_uids, batch_tokens):
|
||||
|
||||
host_seq_desc = self._state_manager.get_or_create_sequence(uid)
|
||||
self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
|
||||
host_seq_desc.pre_forward(tokens.numel())
|
||||
|
||||
# We can disable checks since we already validated schedulability.
|
||||
self._batch.insert_sequence(host_seq_desc, tokens, do_checks=False)
|
||||
|
||||
# Send all metadata to the device
|
||||
self._batch.finalize()
|
||||
|
||||
# Prep all data structures for the actual forward (in anticipation of CG in the future)
|
||||
# and also to amortize some of the costs in a more straightforward way.
|
||||
self._model.prepare_batch(self._batch)
|
||||
|
||||
# Model implementation will pick up in the forward.
|
||||
logits = self._model.forward(self._batch)
|
||||
|
||||
# We return one set of logits per sequence in the batch (saves cost on unembedding)
|
||||
assert logits.shape[0] == self._batch.current_sequences
|
||||
|
||||
for uid in batch_uids:
|
||||
host_seq_desc = self._state_manager.get_sequence(uid)
|
||||
host_seq_desc.post_forward() # Updates sequence metadata.
|
||||
self._model.maybe_free_kv(host_seq_desc)
|
||||
|
||||
return logits
|
||||
|
||||
def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, int]:
|
||||
"""
|
||||
Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
|
||||
(this UID may not be recognized by the model yet), this will return the number of tokens
|
||||
and blocks to reserve for the request.
|
||||
|
||||
Arguments:
|
||||
uid (int): The UID of the sequence (as tracked by the scheduling entity). If
|
||||
this is a new sequence (with a UID unknown to the inference engine), then
|
||||
an empty placeholder is created to pass to the occupancy logic.
|
||||
max_request_tokens (int): The maximum number of tokens to hypothetically send.
max_request_blocks (int): The maximum number of KV blocks that can be allocated to the request.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The number of tokens and the number of KV blocks to reserve
|
||||
for the request.
|
||||
"""
|
||||
seq_desc = self._state_manager.get_sequence(uid)
|
||||
if seq_desc is None:
|
||||
if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences):
|
||||
return (0, 0)
|
||||
seq_desc = PlaceholderSequenceDescriptor()
|
||||
|
||||
req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks)
|
||||
|
||||
return (req_tokens, req_blocks)
|
||||
|
||||
def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult:
|
||||
"""
|
||||
Dry run a batch to determine if it can be scheduled. Placeholder sequences will be
|
||||
created for any UIDs that are unknown to the inference engine.
|
||||
|
||||
Arguments:
|
||||
uids (Iterable[int]): Iterable of UIDs for the batch
|
||||
lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. These lengths
|
||||
correspond to the number of tokens to send in the hypothetical forward; history
|
||||
tokens will be determined via UID lookup and future tokens are disregarded.
|
||||
|
||||
Returns:
|
||||
SchedulingResult: SchedulingResult.Success if the batch can be scheduled, otherwise the specific limit that was exceeded.
|
||||
"""
|
||||
|
||||
cur_seqs = self._state_manager.n_tracked_sequences
|
||||
free_blocks = self._state_manager.free_blocks
|
||||
req_blocks = 0
|
||||
batch_len = 0
|
||||
|
||||
if len(uids) > self._config.state_manager.max_ragged_sequence_count:
|
||||
# Can only compose a batch from a limited number of sequences
|
||||
return SchedulingResult.BatchSequenceLimitExceeded
|
||||
|
||||
for uid, length in zip(uids, lengths):
|
||||
seq_desc = self._state_manager.get_sequence(uid)
|
||||
if seq_desc is None:
|
||||
cur_seqs += 1
|
||||
seq_desc = PlaceholderSequenceDescriptor()
|
||||
|
||||
sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks)
|
||||
|
||||
if sched_len != length:
|
||||
# We ran out of KV cache
|
||||
return SchedulingResult.KVCacheLimitExceeded
|
||||
|
||||
batch_len += length
|
||||
free_blocks -= sched_blocks
|
||||
|
||||
if cur_seqs > self._config.state_manager.max_tracked_sequences:
|
||||
# Would run out of tracking metadata
|
||||
return SchedulingResult.EngineSequenceLimitExceeded
|
||||
|
||||
if batch_len > self._config.state_manager.max_ragged_batch_size:
|
||||
# Would exceed the maximum batch size
|
||||
return SchedulingResult.BatchTokenLimitExceeded
|
||||
|
||||
return SchedulingResult.Success
|
||||
|
||||
def flush(self, uid: int) -> None:
|
||||
"""
|
||||
Remove all state associated with a sequence from the inference engine.
|
||||
|
||||
Arguments:
|
||||
uid (int): The UID of the sequence to flush.
|
||||
"""
|
||||
self._state_manager.flush_sequence(uid)
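Putting the pieces together, a hedged greedy-decode sketch against the `put`/`flush` API above (`engine` as built by `build_hf_engine`; the token ids are placeholders for real tokenizer output):

```python
import torch

uid = 0
prompt = torch.tensor([101, 2054, 2003], dtype=torch.long)   # placeholder token ids
logits = engine.put([uid], [prompt])                         # prefill: one ragged forward
for _ in range(8):                                           # greedy decode a few tokens
    next_token = int(logits[0].argmax())
    logits = engine.put([uid], [torch.tensor([next_token], dtype=torch.long)])
engine.flush(uid)                                            # release KV blocks for this uid
```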
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class NormTypeEnum(Enum):
|
||||
LayerNorm: str = "layer_norm"
|
||||
RMSNorm: str = "rms_norm"
|
||||
|
||||
|
||||
class DtypeEnum(Enum):
|
||||
# The torch dtype must always be the first value (so we return torch.dtype)
|
||||
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
|
||||
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
|
||||
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
|
||||
int8 = torch.int8, "torch.int8", "int8"
|
||||
|
||||
# Copied from https://stackoverflow.com/a/43210118
|
||||
# Allows us to use multiple values for each Enum index and returns first
|
||||
# listed value when Enum is called
|
||||
def __new__(cls, *values):
|
||||
obj = object.__new__(cls)
|
||||
# first value is canonical value
|
||||
obj._value_ = values[0]
|
||||
for other_value in values[1:]:
|
||||
cls._value2member_map_[other_value] = obj
|
||||
obj._all_values = values
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s.%s: %s>" % (
|
||||
self.__class__.__name__,
|
||||
self._name_,
|
||||
", ".join([repr(v) for v in self._all_values]),
|
||||
)
|
||||
|
||||
|
||||
ELEM_SIZES: Dict[torch.dtype, int] = {
|
||||
torch.float16: 2,
|
||||
torch.bfloat16: 2,
|
||||
torch.float32: 4,
|
||||
torch.float64: 8,
|
||||
torch.int8: 1,
|
||||
torch.uint8: 1,
|
||||
torch.int16: 2,
|
||||
torch.int32: 4,
|
||||
torch.int64: 8,
|
||||
torch.bool: 1,
|
||||
}
|
||||
|
||||
|
||||
class ActivationType(IntEnum):
|
||||
"""
|
||||
Types of activations supported by DS-Inference
|
||||
"""
|
||||
|
||||
GELU = 0
|
||||
|
||||
RELU = 1
|
||||
|
||||
SILU = 2
|
||||
|
||||
GEGLU = 3
|
||||
|
||||
ReGLU = 4
|
||||
|
||||
SiGLU = 5
|
||||
|
||||
IDENTITY = 6
|
||||
|
||||
InvalidType = -1
|
||||
|
||||
|
||||
def is_gated(act_fn: ActivationType) -> bool:
|
||||
"""
|
||||
Return True if the given activation function is gated.
|
||||
"""
|
||||
if not isinstance(act_fn, ActivationType):
|
||||
act_fn = ActivationType(act_fn)
|
||||
|
||||
return act_fn in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
|
||||
|
||||
|
||||
def elem_size(dtype: torch.dtype) -> int:
|
||||
"""
|
||||
Return size in bytes of the given dtype.
|
||||
"""
|
||||
try:
|
||||
return ELEM_SIZES[dtype]
|
||||
except KeyError:
|
||||
raise ValueError("Unknown dtype size for {}".format(dtype))
|
||||
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
|
||||
"""
|
||||
Return ceil(a / b).
|
||||
"""
|
||||
return -(-a // b)
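A few concrete checks of the helpers above (values follow directly from the definitions in this file):

```python
import torch

assert DtypeEnum("half") is DtypeEnum.fp16        # every alias resolves to one member
assert DtypeEnum.bf16.value == torch.bfloat16     # canonical value is the torch dtype
assert is_gated(ActivationType.GEGLU) and not is_gated(ActivationType.GELU)
assert elem_size(torch.float32) == 4
assert ceil_div(10, 3) == 4
```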
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .ds_kernel import DSKernelBase
|
|
@ -0,0 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .bias_activations import *
|
||||
from .blas_kernels import *
|
||||
from .cuda_layer_norm import *
|
||||
from .cuda_rms_norm import *
|
||||
from .gated_activations import *
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .bias_activation import *
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "bias_activation.h"
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include "ds_kernel_utils.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#define DTYPE_SWITCH(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kBFloat16) { \
|
||||
using scalar_t = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
|
||||
} \
|
||||
}()
|
||||
#else
|
||||
#define DTYPE_SWITCH(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
/*
|
||||
In-place bias and activation fusion kernel.
|
||||
*/
|
||||
void bias_activation(torch::Tensor& activation,
|
||||
c10::optional<torch::Tensor>& bias,
|
||||
const int32_t act_type)
|
||||
{
|
||||
const ActivationType atype = static_cast<ActivationType>(act_type);
|
||||
const int32_t rows = activation.size(0);
|
||||
const int32_t cols = activation.size(1);
|
||||
|
||||
TORCH_CHECK(atype == ActivationType::GELU || atype == ActivationType::RELU ||
|
||||
atype == ActivationType::SILU || atype == ActivationType::IDENTITY,
|
||||
"Unsupported activation type for BiasActivation");
|
||||
TORCH_CHECK(activation.dim() == 2, "BiasActivation only supports 2D activation tensors");
|
||||
|
||||
DTYPE_SWITCH(activation.scalar_type(), [&] {
|
||||
scalar_t* activation_ptr = reinterpret_cast<scalar_t*>(activation.data_ptr());
|
||||
|
||||
const scalar_t* bias_ptr;
|
||||
if (bias.has_value()) {
|
||||
TORCH_CHECK(activation.scalar_type() == bias.value().scalar_type(),
|
||||
"BiasActivation activation and bias must have same dtype");
|
||||
bias_ptr = reinterpret_cast<const scalar_t*>(bias.value().data_ptr());
|
||||
} else {
|
||||
bias_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (atype == ActivationType::IDENTITY && bias_ptr == nullptr) { return; }
|
||||
|
||||
launch_bias_activation<scalar_t>(
|
||||
activation_ptr, bias_ptr, rows, cols, atype, c10::cuda::getCurrentCUDAStream());
|
||||
});
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <cassert>
|
||||
#include "activation_type.h"
|
||||
#include "conversion_utils.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "memory_access_utils.h"
|
||||
|
||||
// Default activation function will error out
|
||||
template <ActivationType ActType>
|
||||
DS_D_INLINE float act_fn(float val);
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float act_fn<ActivationType::IDENTITY>(float val)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float act_fn<ActivationType::RELU>(float val)
|
||||
{
|
||||
return val > 0.0f ? val : 0.0f;
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float act_fn<ActivationType::GELU>(float val)
|
||||
{
|
||||
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
|
||||
constexpr float mul_param = 0.044715f;
|
||||
return val * 0.5f * (1.0f + tanhf(sqrt_param * (val + mul_param * val * val * val)));
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float act_fn<ActivationType::SILU>(float val)
|
||||
{
|
||||
return val / (1.0f + expf(-val));
|
||||
}
|
||||
|
||||
namespace bias_act {
|
||||
|
||||
constexpr int access_size = 16;
|
||||
constexpr int threads = 512;
|
||||
constexpr int unroll = 4;
|
||||
|
||||
} // namespace bias_act
|
||||
|
||||
template <typename T, ActivationType ActType>
|
||||
__global__ void bias_activation_kernel(T* activation,
|
||||
const T* bias,
|
||||
const int32_t rows,
|
||||
const int32_t cols)
|
||||
{
|
||||
constexpr int vector_T = bias_act::access_size / sizeof(T);
|
||||
|
||||
const int32_t thread_offset = threadIdx.x * vector_T;
|
||||
const int32_t block_offset = blockIdx.x * vector_T * bias_act::unroll * bias_act::threads;
|
||||
const int32_t base_offset = block_offset + thread_offset;
|
||||
|
||||
const int32_t thread_stride = bias_act::threads * vector_T;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < bias_act::unroll; i++) {
|
||||
const int32_t iter_offset = base_offset + i * thread_stride;
|
||||
|
||||
const int32_t row = iter_offset / cols;
|
||||
|
||||
T buffer[vector_T];
|
||||
T bias_buffer[vector_T];
|
||||
|
||||
if (row < rows) {
|
||||
const int32_t col = iter_offset % cols;
|
||||
|
||||
mem_access::load_global<bias_act::access_size>(buffer, activation + iter_offset);
|
||||
mem_access::load_global<bias_act::access_size>(
|
||||
bias_buffer, bias + col, bias != nullptr);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vector_T; j++) {
|
||||
float val =
|
||||
conversion::to<float>(buffer[j]) + conversion::to<float>(bias_buffer[j]);
|
||||
buffer[j] = conversion::to<T>(act_fn<ActType>(val));
|
||||
}
|
||||
|
||||
mem_access::store_global<bias_act::access_size>(activation + iter_offset, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define ACT_TYPE_SWITCH(ACT_TYPE, ...) \
|
||||
if (ACT_TYPE == ActivationType::IDENTITY) { \
|
||||
constexpr ActivationType act_fn_t = ActivationType::IDENTITY; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (ACT_TYPE == ActivationType::RELU) { \
|
||||
constexpr ActivationType act_fn_t = ActivationType::RELU; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (ACT_TYPE == ActivationType::GELU) { \
|
||||
constexpr ActivationType act_fn_t = ActivationType::GELU; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (ACT_TYPE == ActivationType::SILU) { \
|
||||
constexpr ActivationType act_fn_t = ActivationType::SILU; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
assert(false); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_bias_activation(T* activation,
|
||||
const T* bias,
|
||||
const int32_t n_rows,
|
||||
const int32_t n_cols,
|
||||
const ActivationType activation_type,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
constexpr int32_t elems_per_block =
|
||||
bias_act::threads * bias_act::unroll * bias_act::access_size / sizeof(T);
|
||||
const int32_t total_elems = n_rows * n_cols;
|
||||
|
||||
const int32_t blocks = (total_elems + elems_per_block - 1) / elems_per_block;
|
||||
|
||||
const dim3 grid(blocks);
|
||||
const dim3 block(bias_act::threads);
|
||||
|
||||
ACT_TYPE_SWITCH(activation_type, [&] {
|
||||
bias_activation_kernel<T, act_fn_t>
|
||||
<<<grid, block, 0, stream>>>(activation, bias, n_rows, n_cols);
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_T(T) \
|
||||
template void launch_bias_activation<T>( \
|
||||
T*, const T*, const int32_t, const int32_t, const ActivationType, cudaStream_t);
|
||||
|
||||
INSTANTIATE_FOR_T(__half);
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_FOR_T(__nv_bfloat16);
|
||||
#endif
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "activation_type.h"
|
||||
|
||||
template <typename T>
|
||||
void launch_bias_activation(T* activation,
|
||||
const T* bias,
|
||||
const int32_t n_rows,
|
||||
const int32_t n_cols,
|
||||
const ActivationType activation_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
void bias_activation(torch::Tensor& activation,
|
||||
c10::optional<torch::Tensor>& bias,
|
||||
const int32_t activation_type);
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ....inference_utils import ActivationType, DtypeEnum
|
||||
from deepspeed.ops.op_builder import InferenceCoreBuilder
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
class CUDABiasActivation(DSKernelBase):
|
||||
"""
|
||||
CUDA implementation of bias activation kernel. This kernel should be deprecated once
|
||||
we are fusing the bias activation into the linear kernel in all scenarios.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_act_fns = [ActivationType.IDENTITY, ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]
|
||||
|
||||
def __init__(self, channels: int, dtype: DtypeEnum, act_fn: ActivationType) -> None:
|
||||
"""
|
||||
Compile and validate for the fused bias-activation kernel.
|
||||
|
||||
Parameters:
|
||||
channels (int): Number of channels to expect in the activation.
|
||||
dtype (torch.dtype): Data type for the input/output. Supported values
|
||||
are DtypeEnum.fp16 and DtypeEnum.bf16.
|
||||
act_fn (ActivationType): Activation function to use. Only IDENTITY, GELU, RELU, and SILU are supported.
|
||||
"""
|
||||
|
||||
if channels % 8 != 0:
|
||||
raise ValueError("channels must be divisible by 8")
|
||||
|
||||
if DtypeEnum(dtype) not in CUDABiasActivation.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
dtype, CUDABiasActivation.supported_dtypes))
|
||||
|
||||
act_fn = ActivationType(act_fn)
|
||||
if act_fn not in CUDABiasActivation.supported_act_fns:
|
||||
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
|
||||
act_fn, CUDABiasActivation.supported_act_fns))
|
||||
|
||||
inf_module = InferenceCoreBuilder().load()
|
||||
self.kernel = inf_module.bias_activation
|
||||
self.act_fn = act_fn
|
||||
|
||||
def __call__(self, activation: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Add an optional bias and perform the non-linear activation function.
|
||||
|
||||
Parameters:
|
||||
activation (torch.Tensor): Input tensor of shape [tokens, channels]
|
||||
bias (torch.Tensor): Optional bias tensor of shape [channels]
|
||||
|
||||
Returns:
|
||||
activation that has been updated in-place
|
||||
"""
|
||||
self.kernel(activation, bias, self.act_fn.value)
return activation
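Usage sketch (shapes and device are illustrative; requires the InferenceCoreBuilder op to be buildable on the local CUDA setup):

```python
import torch

act = CUDABiasActivation(channels=4096, dtype=DtypeEnum.fp16, act_fn=ActivationType.GELU)
hidden = torch.randn(128, 4096, dtype=torch.float16, device="cuda")
bias = torch.randn(4096, dtype=torch.float16, device="cuda")
act(hidden, bias)   # bias add + GELU applied to `hidden` in place
```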
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .blas_linear import *
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include "blas_utils.h"
|
||||
|
||||
#define DISPATCH_BLAS_MATMUL(T_TYPE, C_TYPE) \
|
||||
if (output.options().dtype() == torch::T_TYPE) { \
|
||||
blas_gemm_ex(output.data_ptr(), \
|
||||
(const void*)weights.data_ptr(), \
|
||||
(const void*)hidden_states.data_ptr(), \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
&alpha, \
|
||||
&beta, \
|
||||
C_TYPE); \
|
||||
}
|
||||
|
||||
void blas_linear(at::Tensor& output, at::Tensor& hidden_states, at::Tensor& weights)
|
||||
{
|
||||
/*
|
||||
Expected shape: output([total_tokens_across_dims], out_neurons)
|
||||
hidden_states([total_tokens_across_dims], in_neurons)
|
||||
weights(out_neurons, in_neurons)
|
||||
|
||||
We are going to assume contiguous for the above shapes.
|
||||
|
||||
The shapes are going to get messed with a little internally to handle column-major
|
||||
GEMMs.
|
||||
*/
|
||||
|
||||
// Number of tokens is N (since the GEMM output is column-major but our Tensor
|
||||
// is row-major, we need to transpose the shapes)
|
||||
const int n = output.numel() / output.size(-1);
|
||||
const int k = weights.size(1);
|
||||
const int m = weights.size(0);
|
||||
|
||||
// A strides
|
||||
const bool trans_a = weights.stride(1) == 1;
|
||||
const int lda = (trans_a) ? weights.stride(0) : weights.stride(1);
|
||||
|
||||
// B strides
|
||||
const bool trans_b = hidden_states.stride(-1) != 1;
|
||||
const int ldb = (trans_b) ? hidden_states.stride(-1) : hidden_states.stride(-2);
|
||||
|
||||
// C strides
|
||||
const int ldc = output.stride(-2);
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
TORCH_CHECK(output.scalar_type() == hidden_states.scalar_type(),
|
||||
"Output and hidden states must have the same scalar type");
|
||||
TORCH_CHECK(output.scalar_type() == weights.scalar_type(),
|
||||
"Output and weights must have the same scalar type");
|
||||
|
||||
// Dispatch the datatypes
|
||||
DISPATCH_BLAS_MATMUL(kFloat, BlasType::FP32);
|
||||
DISPATCH_BLAS_MATMUL(kHalf, BlasType::FP16);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_BLAS_MATMUL(kBFloat16, BlasType::BF16);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_4D_BLAS(T_TYPE, C_TYPE) \
|
||||
if (C.options().dtype() == torch::T_TYPE) { \
|
||||
blas_strided_batched_gemm(C.data_ptr(), \
|
||||
(const void*)A.data_ptr(), \
|
||||
(const void*)B.data_ptr(), \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
&alpha, \
|
||||
&beta, \
|
||||
stride_a, \
|
||||
stride_b, \
|
||||
stride_c, \
|
||||
batch, \
|
||||
C_TYPE); \
|
||||
}
|
||||
|
||||
void blas_4d_matmul(at::Tensor& C, at::Tensor& B, at::Tensor& A)
|
||||
{
|
||||
/*
|
||||
C shape: (batch_size, N, M)
|
||||
A shape: (batch_size, N, K)
|
||||
B shape: (batch_size, K, M)
|
||||
*/
|
||||
|
||||
const int n = C.size(-2);
|
||||
const int k = C.size(-1);
|
||||
const int m = B.size(-1);
|
||||
|
||||
// A strides
|
||||
const bool trans_a = A.stride(-1) == 1;
|
||||
const int lda = (trans_a) ? A.stride(-2) : A.stride(-1);
|
||||
const int stride_a = A.stride(-3);
|
||||
|
||||
// B strides
|
||||
const bool trans_b = B.stride(-1) != 1;
|
||||
const int ldb = (trans_b) ? B.stride(-1) : B.stride(-2);
|
||||
const int stride_b = B.stride(-3);
|
||||
|
||||
// C strides
|
||||
const int ldc = C.stride(-2);
|
||||
const int stride_c = C.stride(-3);
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
const int batch = C.numel() / (n * m);
|
||||
|
||||
// Dispatch the datatypes
|
||||
DISPATCH_4D_BLAS(kFloat, BlasType::FP32);
|
||||
DISPATCH_4D_BLAS(kHalf, BlasType::FP16);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_4D_BLAS(kBFloat16, BlasType::BF16);
|
||||
#endif
|
||||
}
|
||||
|
||||
void create_handle() { BlasContext::getInstance().get_handle(); }
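Note on the shape juggling in blas_linear above: because cuBLAS/rocBLAS are column-major while torch tensors are row-major, the kernel swaps the roles of m and n (m = out_neurons, n = total tokens, k = in_neurons). A minimal, hedged PyTorch sketch of the operation the kernel is expected to match (shapes here are illustrative, not from the PR):

import torch

# blas_linear computes output = hidden_states @ weights.T with
# weights laid out as (out_neurons, in_neurons).
tokens, in_neurons, out_neurons = 8, 128, 256
hidden_states = torch.randn(tokens, in_neurons, dtype=torch.float16, device="cuda")
weights = torch.randn(out_neurons, in_neurons, dtype=torch.float16, device="cuda")

reference = hidden_states @ weights.t()   # row-major view of the column-major GEMM
assert reference.shape == (tokens, out_neurons)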
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ....inference_utils import DtypeEnum
|
||||
from deepspeed.ops.op_builder import InferenceCoreBuilder
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
class BlasLibLinear(DSKernelBase):
|
||||
"""
|
||||
Wrapper around the BLAS matmul kernel for FP16/BF16/FP32 on CUDA/ROCm.
|
||||
|
||||
Performs z = x @ y.T, with y stored as [out_features, in_features]
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]
|
||||
|
||||
def __init__(self, fp_dtype: DtypeEnum):
|
||||
"""
|
||||
Parameters:
|
||||
fp_dtype (torch.dtype): Data type for the input/output. Supported values
|
||||
are torch.float16, torch.bfloat16, and torch.float32.
|
||||
"""
|
||||
fp_dtype = DtypeEnum(fp_dtype)
|
||||
if fp_dtype not in BlasLibLinear.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, BlasLibLinear.supported_dtypes))
|
||||
|
||||
self.inf_module = InferenceCoreBuilder().load()
|
||||
self.inf_module.create_handle()
|
||||
self.kernel = self.inf_module.blas_linear
|
||||
|
||||
def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Matmul kernel as implemented by platform BLAS library. The input must be 2D or larger. If
|
||||
n-dimensional, the leading dimensions are folded into each other:
|
||||
2D: m = x.size(0)
|
||||
3D: m = x.size(0) * x.size(1)
|
||||
4D: m = x.size(0) * x.size(1) * x.size(2) (etc...)
|
||||
All inputs should be contiguous.
|
||||
|
||||
Parameters:
|
||||
output (torch.Tensor): Output tensor. Shape is of [*, out_features]
|
||||
hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features]
|
||||
weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features]
|
||||
|
||||
Returns:
|
||||
output (torch.Tensor): The same output tensor passed in, of shape [*, out_features]
|
||||
"""
|
||||
self.kernel(output, hidden_states, weights)
|
||||
return output
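A hedged usage sketch for BlasLibLinear (assumes a CUDA build of DeepSpeed where the inference core ops compile, and that the class is re-exported from the core_ops package as in the rest of this PR; shapes are illustrative):

import torch
from deepspeed.inference.v2.inference_utils import DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear

linear = BlasLibLinear(DtypeEnum.fp16)

hidden_states = torch.randn(4, 16, 1024, dtype=torch.float16, device="cuda")
weights = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
output = torch.empty(4, 16, 4096, dtype=torch.float16, device="cuda")

# Leading dims of hidden_states/output are folded into m internally.
linear(output, hidden_states, weights)
torch.testing.assert_close(output, hidden_states @ weights.t(), rtol=1e-2, atol=1e-2)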
|
|
@ -0,0 +1,275 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#ifdef BF16_AVAILABLE
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
#include <mma.h>
|
||||
#endif
|
||||
#include <stdio.h>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
class BlasContext {
|
||||
/*
|
||||
Slim wrapper for managing the lifetime of the platform's BLAS handle. This should
|
||||
be hipified for ROCm.
|
||||
*/
|
||||
public:
|
||||
BlasContext()
|
||||
{
|
||||
if (cublasCreate(&_handle) != CUBLAS_STATUS_SUCCESS) {
|
||||
auto message = std::string("Failed to create cublas handle.");
|
||||
std::cerr << message << std::endl;
|
||||
throw std::runtime_error(message);
|
||||
}
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH);
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual ~BlasContext() { cublasDestroy(_handle); }
|
||||
|
||||
static BlasContext& getInstance()
|
||||
{
|
||||
// Should always access the singleton through this function.
|
||||
static BlasContext _instance;
|
||||
return _instance;
|
||||
}
|
||||
|
||||
cublasHandle_t get_handle() const { return _handle; }
|
||||
|
||||
private:
|
||||
cublasHandle_t _handle;
|
||||
};
|
||||
|
||||
enum class BlasType { FP32, FP16, BF16 };
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
rocblas_operation get_trans_op(bool do_trans)
|
||||
{
|
||||
return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
|
||||
}
|
||||
|
||||
rocblas_datatype get_datatype(BlasType type)
|
||||
{
|
||||
switch (type) {
|
||||
case BlasType::FP32: return rocblas_datatype_f32_r;
|
||||
case BlasType::FP16: return rocblas_datatype_f16_r;
|
||||
case BlasType::BF16: return rocblas_datatype_bf16_r;
|
||||
default: throw std::runtime_error("Unsupported BlasType");
|
||||
}
|
||||
}
|
||||
#else
|
||||
cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T : CUBLAS_OP_N; }
|
||||
|
||||
cublasDataType_t get_datatype(BlasType type)
|
||||
{
|
||||
switch (type) {
|
||||
case BlasType::FP32: return CUDA_R_32F;
|
||||
case BlasType::FP16: return CUDA_R_16F;
|
||||
case BlasType::BF16: return CUDA_R_16BF;
|
||||
default: throw std::runtime_error("Unsupported BlasType");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
int blas_gemm_ex(void* C,
|
||||
const void* A,
|
||||
const void* B,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
bool transa,
|
||||
bool transb,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
BlasType type)
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
rocblas_operation transa_op = get_trans_op(transa);
|
||||
rocblas_operation transb_op = get_trans_op(transb);
|
||||
|
||||
rocblas_datatype abc_type = get_datatype(type);
|
||||
|
||||
rocblas_status status = rocblas_gemm_ex(BlasContext::getInstance().get_handle(),
|
||||
transa_op,
|
||||
transb_op,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(const void*)alpha,
|
||||
A,
|
||||
abc_type,
|
||||
lda,
|
||||
B,
|
||||
abc_type,
|
||||
ldb,
|
||||
(const void*)beta,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
cublasOperation_t transa_op = get_trans_op(transa);
|
||||
cublasOperation_t transb_op = get_trans_op(transb);
|
||||
|
||||
cublasDataType_t abc_type = get_datatype(type);
|
||||
cublasStatus_t status = cublasGemmEx(BlasContext::getInstance().get_handle(),
|
||||
transa_op,
|
||||
transb_op,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(const void*)alpha,
|
||||
A,
|
||||
abc_type,
|
||||
lda,
|
||||
B,
|
||||
abc_type,
|
||||
ldb,
|
||||
(const void*)beta,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
CUDA_R_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
#endif
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int blas_strided_batched_gemm(void* C,
|
||||
const void* A,
|
||||
const void* B,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
bool transa,
|
||||
bool transb,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch,
|
||||
BlasType type)
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
rocblas_operation transa_op = get_trans_op(transa);
|
||||
rocblas_operation transb_op = get_trans_op(transb);
|
||||
|
||||
rocblas_datatype abc_type = get_datatype(type);
|
||||
|
||||
rocblas_status status =
|
||||
rocblas_gemm_strided_batched_ex(BlasContext::getInstance().get_handle(),
|
||||
transa_op,
|
||||
transb_op,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(const void*)alpha,
|
||||
A,
|
||||
abc_type,
|
||||
lda,
|
||||
stride_A,
|
||||
B,
|
||||
abc_type,
|
||||
ldb,
|
||||
stride_B,
|
||||
(const void*)beta,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
stride_C,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
stride_C,
|
||||
batch,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
cublasOperation_t transa_op = get_trans_op(transa);
|
||||
cublasOperation_t transb_op = get_trans_op(transb);
|
||||
|
||||
cublasDataType_t abc_type = get_datatype(type);
|
||||
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(BlasContext::getInstance().get_handle(),
|
||||
transa_op,
|
||||
transb_op,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(const void*)alpha,
|
||||
A,
|
||||
abc_type,
|
||||
lda,
|
||||
stride_A,
|
||||
B,
|
||||
abc_type,
|
||||
ldb,
|
||||
stride_B,
|
||||
(const void*)beta,
|
||||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
stride_C,
|
||||
batch,
|
||||
CUDA_R_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
#endif
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
|
||||
batch,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "bias_activation.h"
|
||||
#include "blas.h"
|
||||
#include "gated_activation_kernels.h"
|
||||
#include "layer_norm.h"
|
||||
#include "rms_norm.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
// bias_activation.h
|
||||
m.def("bias_activation", &bias_activation, "DeepSpeed bias activation in CUDA");
|
||||
|
||||
// layer_norm.h
|
||||
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm in CUDA");
|
||||
m.def("pre_layer_norm", &ds_pre_layer_norm, "DeepSpeed pre layer norm in CUDA");
|
||||
m.def("post_layer_norm", &ds_post_layer_norm, "DeepSpeed pre layer norm in CUDA");
|
||||
|
||||
// blas.h
|
||||
m.def("blas_linear", &blas_linear, "Linear implemented by vendor BLAS");
|
||||
m.def("blas_4d_matmul", &blas_4d_matmul, "4D matmul implemented by vendor BLAS");
|
||||
m.def("create_handle", &create_handle, "Create a handle for vendor BLAS");
|
||||
|
||||
// gated_activation_kernels.h
|
||||
m.def("gated_activation", &ds_gated_activation, "DeepSpeed gated activation in CUDA");
|
||||
|
||||
// rms_norm.h
|
||||
m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
|
||||
m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");
|
||||
}
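A small, hedged sanity check that the extension built from the PYBIND11_MODULE block above exposes the expected entry points (assumes InferenceCoreBuilder JIT-compiles on the local CUDA toolchain):

from deepspeed.ops.op_builder import InferenceCoreBuilder

module = InferenceCoreBuilder().load()
for name in ("bias_activation", "layer_norm", "pre_layer_norm", "post_layer_norm",
             "blas_linear", "blas_4d_matmul", "create_handle",
             "gated_activation", "rms_norm", "rms_pre_norm"):
    assert hasattr(module, name), f"missing binding: {name}"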
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .cuda_ln import *
|
||||
from .cuda_post_ln import *
|
||||
from .cuda_pre_ln import *
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import elem_size
|
||||
from deepspeed.ops.op_builder import InferenceCoreBuilder
|
||||
|
||||
|
||||
class CUDAFPLNBase(DSKernelBase):
|
||||
"""
|
||||
Base class for CUDA LN kernels. They all use the same validation logic,
|
||||
so we can share it here.
|
||||
"""
|
||||
|
||||
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
|
||||
"""
|
||||
Parameters:
|
||||
channels (int): Number of channels in the input tensor. elem_size(fp_dtype) * channels must align
|
||||
to 16 bytes.
|
||||
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
|
||||
are torch.float16, torch.bfloat16, and torch.float32.
|
||||
"""
|
||||
if fp_dtype not in CUDAFPLNBase.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, CUDAFPLNBase.supported_dtypes))
|
||||
|
||||
if elem_size(fp_dtype) * channels % 16 != 0:
|
||||
raise ValueError("channels must be divisible by 16 bytes")
|
||||
|
||||
self.inf_module = InferenceCoreBuilder().load()
|
||||
self.epsilon = epsilon
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from .cuda_fp_ln_base import CUDAFPLNBase
|
||||
|
||||
|
||||
class CUDAFPLN(CUDAFPLNBase):
|
||||
"""
|
||||
Floating point layer norm kernel for CUDA/ROCm.
|
||||
|
||||
Performs: z = ln(x)
|
||||
"""
|
||||
|
||||
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor,
|
||||
beta: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
output_z may alias input_x directly. All Tensors should have the same shape.
|
||||
|
||||
Parameters:
|
||||
output_z (torch.Tensor): Output tensor.
|
||||
input_x (torch.Tensor): Input tensor.
|
||||
gamma (torch.Tensor): Gamma tensor.
|
||||
beta (torch.Tensor): Beta tensor.
|
||||
"""
|
||||
self.inf_module.layer_norm(output_z, input_x, gamma, beta, self.epsilon)
|
||||
return output_z
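A hedged usage sketch for CUDAFPLN checked against torch's layer_norm (import path assumed from this PR's package layout; requires a CUDA device):

import torch
from deepspeed.inference.v2.kernels.core_ops import CUDAFPLN

channels = 4096
ln = CUDAFPLN(channels, torch.float16)

x = torch.randn(32, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
beta = torch.zeros(channels, dtype=torch.float16, device="cuda")
z = torch.empty_like(x)

ln(z, x, gamma, beta)
ref = torch.nn.functional.layer_norm(x, (channels, ), gamma, beta, eps=1e-5)
torch.testing.assert_close(z, ref, rtol=1e-2, atol=1e-2)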
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from .cuda_fp_ln_base import CUDAFPLNBase
|
||||
|
||||
|
||||
class CUDAFPPostLN(CUDAFPLNBase):
|
||||
"""
|
||||
Floating point post-LayerNorm kernel for CUDA/ROCm.
|
||||
|
||||
Performs: z = ln(x + y)
|
||||
"""
|
||||
|
||||
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, input_y: torch.Tensor, gamma: torch.Tensor,
|
||||
beta: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Either input_x or input_y can alias output_z.
|
||||
|
||||
Parameters:
|
||||
output_z (torch.Tensor): Output tensor.
|
||||
input_x (torch.Tensor): Input tensor.
|
||||
input_y (torch.Tensor): Input tensor.
|
||||
gamma (torch.Tensor): Gamma tensor.
|
||||
beta (torch.Tensor): Beta tensor.
|
||||
|
||||
Returns:
|
||||
output (torch.Tensor): Output tensor.
|
||||
"""
|
||||
self.inf_module.post_layer_norm(output_z, input_x, input_y, gamma, beta, self.epsilon)
|
||||
return output_z
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .cuda_fp_ln_base import CUDAFPLNBase
|
||||
|
||||
|
||||
class CUDAFPPreLN(CUDAFPLNBase):
|
||||
"""
|
||||
Floating point pre-LayerNorm kernel for CUDA/ROCm.
|
||||
|
||||
Performs: z_res = x_res + y_hid
|
||||
z_hid = ln(z_res)
|
||||
"""
|
||||
|
||||
def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
|
||||
gamma: torch.Tensor, beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
z_res can alias x_res. All non-parameter input/output tensors
|
||||
must have the same shape. z_hid can alias y_hid.
|
||||
|
||||
Parameters:
|
||||
z_res (torch.Tensor): Output residual.
|
||||
z_hid (torch.Tensor): Output hidden states.
|
||||
x_res (torch.Tensor): Input residual.
|
||||
y_hid (torch.Tensor): Input hidden states.
|
||||
gamma (torch.Tensor): Gamma tensor.
|
||||
beta (torch.Tensor): Beta tensor.
|
||||
|
||||
Returns:
|
||||
(z_res, z_hid) (Tuple[torch.Tensor, torch.Tensor]): The output residual and normalized hidden states.
|
||||
"""
|
||||
self.inf_module.pre_layer_norm(z_res, z_hid, x_res, y_hid, gamma, beta, self.epsilon)
|
||||
return z_res, z_hid
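A hedged usage sketch for the pre-LN wrapper above (import path and shapes are assumptions based on this PR's layout); it shows the two outputs and the residual contract:

import torch
from deepspeed.inference.v2.kernels.core_ops import CUDAFPPreLN

channels = 2048
pre_ln = CUDAFPPreLN(channels, torch.float16)

x_res = torch.randn(64, channels, dtype=torch.float16, device="cuda")
y_hid = torch.randn(64, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
beta = torch.zeros(channels, dtype=torch.float16, device="cuda")

# z_res receives x_res + y_hid; z_hid receives ln(x_res + y_hid).
z_res = torch.empty_like(x_res)
z_hid = torch.empty_like(y_hid)
z_res, z_hid = pre_ln(z_res, z_hid, x_res, y_hid, gamma, beta)

torch.testing.assert_close(z_res, x_res + y_hid, rtol=1e-2, atol=1e-2)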
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "layer_norm.h"
|
||||
|
||||
#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \
|
||||
if (input.options().dtype() == torch::T_TYPE) { \
|
||||
launch_fused_ln((C_TYPE*)output.data_ptr(), \
|
||||
(const C_TYPE*)input.data_ptr(), \
|
||||
(const C_TYPE*)gamma.data_ptr(), \
|
||||
(const C_TYPE*)beta.data_ptr(), \
|
||||
epsilon, \
|
||||
rows, \
|
||||
elems_per_row, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
}
|
||||
|
||||
void ds_layer_norm(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon)
|
||||
{
|
||||
bool ragged_input = input.dim() == 2;
|
||||
|
||||
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
|
||||
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
|
||||
|
||||
DISPATCH_LAYER_NORM(kFloat, float);
|
||||
DISPATCH_LAYER_NORM(kHalf, __half);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
|
||||
if (input.options().dtype() == torch::T_TYPE) { \
|
||||
launch_fused_post_ln((C_TYPE*)output.data_ptr(), \
|
||||
(const C_TYPE*)input.data_ptr(), \
|
||||
(const C_TYPE*)residual.data_ptr(), \
|
||||
(const C_TYPE*)gamma.data_ptr(), \
|
||||
(const C_TYPE*)beta.data_ptr(), \
|
||||
epsilon, \
|
||||
rows, \
|
||||
elems_per_row, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
}
|
||||
|
||||
void ds_post_layer_norm(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& residual,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon)
|
||||
{
|
||||
bool ragged_input = input.dim() == 2;
|
||||
|
||||
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
|
||||
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
|
||||
|
||||
DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float);
|
||||
DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
|
||||
if (input.options().dtype() == torch::T_TYPE) { \
|
||||
launch_fused_pre_ln((C_TYPE*)norm_output.data_ptr(), \
|
||||
(C_TYPE*)res_output.data_ptr(), \
|
||||
(const C_TYPE*)input.data_ptr(), \
|
||||
(const C_TYPE*)residual.data_ptr(), \
|
||||
(const C_TYPE*)gamma.data_ptr(), \
|
||||
(const C_TYPE*)beta.data_ptr(), \
|
||||
epsilon, \
|
||||
rows, \
|
||||
elems_per_row, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
}
|
||||
|
||||
void ds_pre_layer_norm(at::Tensor& res_output,
|
||||
at::Tensor& norm_output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& residual,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon)
|
||||
{
|
||||
bool ragged_input = input.dim() == 2;
|
||||
|
||||
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
|
||||
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
|
||||
|
||||
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float);
|
||||
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
|
@ -0,0 +1,490 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "conversion_utils.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "memory_access_utils.h"
|
||||
#include "reduction_utils.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
using rop = reduce::ROpType;
|
||||
|
||||
namespace ln {
|
||||
constexpr int granularity = 16;
|
||||
} // namespace ln
|
||||
|
||||
/*
|
||||
Regular layer norm implementation. Assumes elems_per_row % 8
|
||||
is equal to 0.
|
||||
|
||||
Args:
|
||||
output: buffer for output data
|
||||
vals: buffer for input data
|
||||
gamma: gain for normalization
|
||||
beta: bias for normalization
|
||||
epsilon: numeric stability
|
||||
elems_per_row: number of elements each block will normalize
|
||||
*/
|
||||
template <typename T, int unRoll, int threadsPerGroup, int maxThreads>
|
||||
__global__ void fused_ln(T* output,
|
||||
const T* vals,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int elems_per_row)
|
||||
{
|
||||
constexpr int T_per_load = ln::granularity / sizeof(T);
|
||||
|
||||
cg::thread_block tb = cg::this_thread_block();
|
||||
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
||||
|
||||
// X-dimension of the block
|
||||
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
|
||||
(tb.thread_index().y * elems_per_row);
|
||||
const int thread_offset = tb.thread_index().x * T_per_load;
|
||||
const int base_offset = block_offset + thread_offset;
|
||||
const int stride = blockDim.x * T_per_load;
|
||||
|
||||
float sum = reduce::init<rop::Add, float>();
|
||||
|
||||
const T* input_base = vals + base_offset;
|
||||
|
||||
T local_buffer[unRoll * T_per_load];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
T* iteration_buffer = local_buffer + i * T_per_load;
|
||||
|
||||
mem_access::load_global<ln::granularity>(
|
||||
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
|
||||
sum = reduce::element<rop::Add>(sum, vals_up_cast);
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
|
||||
const float mean = sum / elems_per_row;
|
||||
|
||||
float mean_diff = reduce::init<rop::Add, float>();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
// Using a 0 value here skews the variance, have to if-guard
|
||||
if (thread_offset + i * stride < elems_per_row) {
|
||||
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
|
||||
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
|
||||
const float variance = mean_diff / elems_per_row;
|
||||
const float denom = __frsqrt_rn(variance + epsilon);
|
||||
|
||||
T* block_output = output + block_offset;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
T* iteration_buffer = local_buffer + i * T_per_load;
|
||||
const int iter_idx = i * stride + thread_offset;
|
||||
const bool do_loads = iter_idx < elems_per_row;
|
||||
|
||||
T gamma_local[T_per_load], beta_local[T_per_load];
|
||||
|
||||
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
|
||||
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
float val = conversion::to<float>(iteration_buffer[j]);
|
||||
val = (val - mean) * denom;
|
||||
val =
|
||||
val * conversion::to<float>(gamma_local[j]) + conversion::to<float>(beta_local[j]);
|
||||
iteration_buffer[j] = conversion::to<T>(val);
|
||||
}
|
||||
|
||||
if (do_loads) {
|
||||
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \
|
||||
fused_ln<T, unRollFactor, threadsPerGroup, maxThreads> \
|
||||
<<<grid, block, 0, stream>>>(output, vals, gamma, beta, epsilon, elems_per_row);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_ln(T* output,
|
||||
const T* vals,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// 8 for __half, 4 for float
|
||||
constexpr int T_per_load = ln::granularity / sizeof(T);
|
||||
|
||||
constexpr int maxThreads = 256;
|
||||
|
||||
// For float, unroll 4; for __half, unroll 2
|
||||
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
|
||||
|
||||
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
|
||||
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
|
||||
|
||||
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
|
||||
// warp-sized blocks rather than stepping up to 64/96 threads
|
||||
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
|
||||
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
|
||||
|
||||
const int groups_per_block_max =
|
||||
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
|
||||
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
|
||||
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
|
||||
|
||||
dim3 block(threadsPerGroup, groups_per_block);
|
||||
dim3 grid(groups_launch);
|
||||
|
||||
const int elems_per_step = threadsPerGroup * h_per_step;
|
||||
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
|
||||
|
||||
if (is_subblock_schedule) {
|
||||
// <=128
|
||||
if (threadsPerGroup == 1) {
|
||||
LAUNCH_FUSED_LN(1, 1, maxThreads);
|
||||
} else if (threadsPerGroup == 2) {
|
||||
LAUNCH_FUSED_LN(1, 2, maxThreads);
|
||||
} else if (threadsPerGroup == 4) {
|
||||
LAUNCH_FUSED_LN(1, 4, maxThreads);
|
||||
} else if (threadsPerGroup == 8) {
|
||||
LAUNCH_FUSED_LN(1, 8, maxThreads);
|
||||
} else if (threadsPerGroup == 16) {
|
||||
LAUNCH_FUSED_LN(1, 16, maxThreads);
|
||||
}
|
||||
} else if (external_unRoll == 1) {
|
||||
// 129 - 4096 elems
|
||||
// (this can launch with 1-7 warps as well)
|
||||
LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 2) {
|
||||
// 4097 - 8192 elems
|
||||
LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 3) {
|
||||
// 8193 - 12288 elems
|
||||
LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 4) {
|
||||
// 12289 - 16384 elems
|
||||
LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FUSED_LN(T) \
|
||||
template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
|
||||
|
||||
INSTANTIATE_FUSED_LN(__half);
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_FUSED_LN(__nv_bfloat16);
|
||||
#endif
|
||||
INSTANTIATE_FUSED_LN(float);
|
||||
|
||||
/*
|
||||
Fused residual + bias + layer norm implementation. Assumes elems_per_row % 8
|
||||
is equal to 0.
|
||||
|
||||
TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
|
||||
need to be fused into compute-bound producer operations.
|
||||
|
||||
Args:
|
||||
output: buffer for output data
|
||||
res_output: output of residual addition
|
||||
vals: buffer for input data
|
||||
residual: residual data
|
||||
bias: bias of input data
|
||||
gamma: gain for normalization
|
||||
beta: bias for normalization
|
||||
epsilon: numeric stability
|
||||
elems_per_row: number of elements each block will normalize
|
||||
Template arg:
|
||||
preLnResidual: controls whether the residual calculation is stored
|
||||
or not. When set to false, the input `res_output` is unused.
|
||||
*/
|
||||
template <typename T, int unRoll, int threadsPerGroup, int maxThreads, bool preLnResidual>
|
||||
__global__ void fused_residual_ln(T* output,
|
||||
T* res_output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int elems_per_row)
|
||||
{
|
||||
constexpr int T_per_load = ln::granularity / sizeof(T);
|
||||
|
||||
cg::thread_block tb = cg::this_thread_block();
|
||||
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
||||
|
||||
// X-dimension of the block
|
||||
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
|
||||
(tb.thread_index().y * elems_per_row);
|
||||
const int thread_offset = tb.thread_index().x * T_per_load;
|
||||
const int base_offset = block_offset + thread_offset;
|
||||
const int stride = tb.size() * T_per_load;
|
||||
|
||||
float sum = reduce::init<rop::Add, float>();
|
||||
|
||||
const T* input_base = vals + base_offset;
|
||||
const T* residual_base = residual + base_offset;
|
||||
|
||||
T local_buffer[unRoll * T_per_load];
|
||||
|
||||
// Unlike a vanilla layernorm, since we're fusing the two adds as well
|
||||
// an inner unRoll seems to be less valuable. If anything, a double unRoll
|
||||
// makes the most sense if we find we are having performance issues.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
T* iteration_buffer = local_buffer + i * T_per_load;
|
||||
T residual_buffer[T_per_load];
|
||||
T bias_buffer[T_per_load];
|
||||
|
||||
mem_access::load_global<ln::granularity>(
|
||||
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
|
||||
mem_access::load_global<ln::granularity>(residual_buffer,
|
||||
residual_base + i * stride,
|
||||
thread_offset + i * stride < elems_per_row);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
|
||||
float res_up_cast = conversion::to<float>(residual_buffer[j]);
|
||||
vals_up_cast += res_up_cast;
|
||||
sum = reduce::element<rop::Add>(sum, vals_up_cast);
|
||||
iteration_buffer[j] = conversion::to<T>(vals_up_cast);
|
||||
}
|
||||
|
||||
if (preLnResidual && (thread_offset + i * stride < elems_per_row)) {
|
||||
mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
|
||||
iteration_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
|
||||
const float mean = sum / elems_per_row;
|
||||
|
||||
float mean_diff = reduce::init<rop::Add, float>();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
// Using a 0 value here skews the variance, have to if-guard
|
||||
if (thread_offset + i * stride < elems_per_row) {
|
||||
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
|
||||
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
|
||||
const float variance = mean_diff / elems_per_row;
|
||||
const float denom = __frsqrt_rn(variance + epsilon);
|
||||
|
||||
T* block_output = output + block_offset;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unRoll; i++) {
|
||||
T* iteration_buffer = local_buffer + i * T_per_load;
|
||||
const int iter_idx = i * stride + thread_offset;
|
||||
const bool do_loads = iter_idx < elems_per_row;
|
||||
|
||||
T gamma_local[T_per_load], beta_local[T_per_load];
|
||||
|
||||
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
|
||||
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
float val = conversion::to<float>(iteration_buffer[j]);
|
||||
val = (val - mean) * denom;
|
||||
val =
|
||||
val * conversion::to<float>(gamma_local[j]) + conversion::to<float>(beta_local[j]);
|
||||
iteration_buffer[j] = conversion::to<T>(val);
|
||||
}
|
||||
|
||||
if (do_loads) {
|
||||
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
|
||||
#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \
|
||||
fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, false> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
output, nullptr, vals, residual, gamma, beta, epsilon, elems_per_row);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_post_ln(T* output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// 8 for __half, 4 for float
|
||||
constexpr int T_per_load = ln::granularity / sizeof(T);
|
||||
|
||||
constexpr int maxThreads = 256;
|
||||
|
||||
// For float, unroll 4; for __half, unroll 2
|
||||
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
|
||||
|
||||
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
|
||||
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
|
||||
|
||||
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
|
||||
// warp-sized blocks rather than stepping up to 64/96 threads
|
||||
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
|
||||
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
|
||||
|
||||
const int groups_per_block_max =
|
||||
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
|
||||
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
|
||||
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
|
||||
|
||||
dim3 block(threadsPerGroup, groups_per_block);
|
||||
dim3 grid(groups_launch);
|
||||
|
||||
const int elems_per_step = threadsPerGroup * h_per_step;
|
||||
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
|
||||
|
||||
if (is_subblock_schedule) {
|
||||
// <=128
|
||||
if (threadsPerGroup == 1) {
|
||||
LAUNCH_FUSED_RES_LN(1, 1, maxThreads);
|
||||
} else if (threadsPerGroup == 2) {
|
||||
LAUNCH_FUSED_RES_LN(1, 2, maxThreads);
|
||||
} else if (threadsPerGroup == 4) {
|
||||
LAUNCH_FUSED_RES_LN(1, 4, maxThreads);
|
||||
} else if (threadsPerGroup == 8) {
|
||||
LAUNCH_FUSED_RES_LN(1, 8, maxThreads);
|
||||
} else if (threadsPerGroup == 16) {
|
||||
LAUNCH_FUSED_RES_LN(1, 16, maxThreads);
|
||||
}
|
||||
} else if (external_unRoll == 1) {
|
||||
// 129 - 4096 elems
|
||||
// (this can launch with 1-7 warps as well)
|
||||
LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 2) {
|
||||
// 4097 - 8192 elems
|
||||
LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 3) {
|
||||
// 8193 - 12288 elems
|
||||
LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 4) {
|
||||
// 12289 - 16384 elems
|
||||
LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \
|
||||
fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, true> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
norm_output, res_output, vals, residual, gamma, beta, epsilon, elems_per_row);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_pre_ln(T* norm_output,
|
||||
T* res_output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// 8 for __half, 4 for float
|
||||
constexpr int T_per_load = ln::granularity / sizeof(T);
|
||||
|
||||
constexpr int maxThreads = 256;
|
||||
|
||||
// For float, unroll 4; for __half, unroll 2
|
||||
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
|
||||
|
||||
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
|
||||
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
|
||||
|
||||
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
|
||||
// warp-sized blocks rather than stepping up to 64/96 threads
|
||||
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
|
||||
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
|
||||
|
||||
const int groups_per_block_max =
|
||||
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
|
||||
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
|
||||
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
|
||||
|
||||
dim3 block(threadsPerGroup, groups_per_block);
|
||||
dim3 grid(groups_launch);
|
||||
|
||||
const int elems_per_step = threadsPerGroup * h_per_step;
|
||||
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
|
||||
|
||||
if (is_subblock_schedule) {
|
||||
// <=128
|
||||
if (threadsPerGroup == 1) {
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads);
|
||||
} else if (threadsPerGroup == 2) {
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads);
|
||||
} else if (threadsPerGroup == 4) {
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads);
|
||||
} else if (threadsPerGroup == 8) {
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads);
|
||||
} else if (threadsPerGroup == 16) {
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads);
|
||||
}
|
||||
} else if (external_unRoll == 1) {
|
||||
// 129 - 4096 elems
|
||||
// (this can launch with 1-7 warps as well)
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 2) {
|
||||
// 4097 - 8192 elems
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 3) {
|
||||
// 8193 - 12288 elems
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 4) {
|
||||
// 12289 - 16384 elems
|
||||
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_RES_LN(T) \
|
||||
template void launch_fused_post_ln<T>( \
|
||||
T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
|
||||
|
||||
#define INSTANTIATE_PRE_LN_RES(T) \
|
||||
template void launch_fused_pre_ln<T>( \
|
||||
T*, T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
|
||||
|
||||
INSTANTIATE_RES_LN(__half);
|
||||
INSTANTIATE_RES_LN(float);
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_RES_LN(__nv_bfloat16);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_PRE_LN_RES(__half);
|
||||
INSTANTIATE_PRE_LN_RES(float);
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_PRE_LN_RES(__nv_bfloat16);
|
||||
#endif
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "ds_kernel_utils.h"
|
||||
|
||||
/*
|
||||
Kernel launch methods for layer norm variants.
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_ln(T* output,
|
||||
const T* vals,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_post_ln(T* output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream);
|
||||
template <typename T>
|
||||
void launch_fused_pre_ln(T* norm_output,
|
||||
T* res_output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream);
|
||||
|
||||
void ds_layer_norm(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon);
|
||||
|
||||
void ds_post_layer_norm(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& residual,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon);
|
||||
|
||||
void ds_pre_layer_norm(at::Tensor& res_output,
|
||||
at::Tensor& norm_output,
|
||||
at::Tensor& input,
|
||||
at::Tensor& residual,
|
||||
at::Tensor& gamma,
|
||||
at::Tensor& beta,
|
||||
float epsilon);
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .rms_norm import CUDARMSNorm
|
||||
from .rms_pre_norm import CUDARMSPreNorm
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "rms_norm.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using scalar_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kBFloat16) { \
|
||||
using scalar_t = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for RMSNorm"); \
|
||||
} \
|
||||
}()
|
||||
#else
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using scalar_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for RMSNorm"); \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& norm_output,
|
||||
torch::Tensor& norm_input,
|
||||
torch::Tensor& gamma,
|
||||
float epsilon)
|
||||
{
|
||||
TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
|
||||
"norm_output and norm_input should have the same data type");
|
||||
TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
|
||||
"norm_output and gamma should have the same data type");
|
||||
|
||||
const int32_t rows = norm_input.size(0);
|
||||
const int32_t cols = norm_input.size(1);
|
||||
|
||||
TORCH_CHECK(norm_output.size(0) == rows,
|
||||
"norm_output and norm_input should have the same first dimension");
|
||||
TORCH_CHECK(norm_output.size(1) == cols,
|
||||
"norm_output and norm_input should have the same second dimension");
|
||||
|
||||
DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
|
||||
scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
|
||||
scalar_t* norm_input_ptr = reinterpret_cast<scalar_t*>(norm_input.data_ptr());
|
||||
scalar_t* gamma_ptr = reinterpret_cast<scalar_t*>(gamma.data_ptr());
|
||||
scalar_t* null_t = nullptr;
|
||||
|
||||
launch_rms_norm(norm_output_ptr,
|
||||
null_t,
|
||||
norm_input_ptr,
|
||||
null_t,
|
||||
gamma_ptr,
|
||||
epsilon,
|
||||
rows,
|
||||
cols,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
});
|
||||
}
|
||||
|
||||
void rms_pre_norm(torch::Tensor& norm_output,
|
||||
torch::Tensor& residual_output,
|
||||
torch::Tensor& norm_input,
|
||||
torch::Tensor& residual_input,
|
||||
torch::Tensor& gamma,
|
||||
float epsilon)
|
||||
{
|
||||
TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
|
||||
"norm_output and norm_input should have the same data type");
|
||||
TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
|
||||
"norm_output and gamma should have the same data type");
|
||||
|
||||
const int32_t rows = norm_input.size(0);
|
||||
const int32_t cols = norm_input.size(1);
|
||||
|
||||
TORCH_CHECK(norm_output.size(0) == rows,
|
||||
"norm_output and norm_input should have the same first dimension");
|
||||
TORCH_CHECK(norm_output.size(1) == cols,
|
||||
"norm_output and norm_input should have the same second dimension");
|
||||
|
||||
TORCH_CHECK(residual_output.size(0) == rows,
|
||||
"residual_output and norm_input should have the same first dimension");
|
||||
TORCH_CHECK(residual_output.size(1) == cols,
|
||||
"residual_output and norm_input should have the same second dimension");
|
||||
|
||||
TORCH_CHECK(residual_input.size(0) == rows,
|
||||
"residual_input and norm_input should have the same first dimension");
|
||||
TORCH_CHECK(residual_input.size(1) == cols,
|
||||
"residual_input and norm_input should have the same second dimension");
|
||||
|
||||
DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
|
||||
scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
|
||||
scalar_t* residual_output_ptr = reinterpret_cast<scalar_t*>(residual_output.data_ptr());
|
||||
const scalar_t* norm_input_ptr = reinterpret_cast<const scalar_t*>(norm_input.data_ptr());
|
||||
const scalar_t* residual_input_ptr =
|
||||
reinterpret_cast<const scalar_t*>(residual_input.data_ptr());
|
||||
const scalar_t* gamma_ptr = reinterpret_cast<const scalar_t*>(gamma.data_ptr());
|
||||
|
||||
launch_rms_norm(norm_output_ptr,
|
||||
residual_output_ptr,
|
||||
norm_input_ptr,
|
||||
residual_input_ptr,
|
||||
gamma_ptr,
|
||||
epsilon,
|
||||
rows,
|
||||
cols,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
});
|
||||
}
|
|
@ -0,0 +1,262 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "conversion_utils.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "memory_access_utils.h"
|
||||
#include "reduction_utils.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
using rop = reduce::ROpType;
|
||||
|
||||
namespace rms {
|
||||
constexpr int granularity = 16;
|
||||
} // namespace rms
|
||||
|
||||
template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
|
||||
__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row)
|
||||
{
|
||||
constexpr int T_per_load = rms::granularity / sizeof(T);
|
||||
|
||||
cg::thread_block tb = cg::this_thread_block();
|
||||
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
||||
|
||||
// X-dimension of the block
|
||||
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
|
||||
(tb.thread_index().y * elems_per_row);
|
||||
const int thread_offset = tb.thread_index().x * T_per_load;
|
||||
const int base_offset = block_offset + thread_offset;
|
||||
const int stride = blockDim.x * T_per_load;
|
||||
|
||||
float var_sum = reduce::init<rop::Add, float>();
|
||||
|
||||
const T* input_base = vals + base_offset;
|
||||
|
||||
T local_buffer[UNROLL * T_per_load];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL; i++) {
|
||||
T* iteration_buffer = local_buffer + (i * T_per_load);
|
||||
|
||||
mem_access::load_global<rms::granularity>(iteration_buffer,
|
||||
input_base + (i * stride),
|
||||
thread_offset + (i * stride) < elems_per_row);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
float up_cast = conversion::to<float>(iteration_buffer[j]);
|
||||
float sq_val = up_cast * up_cast;
|
||||
var_sum = reduce::element<rop::Add, float>(var_sum, sq_val);
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, var_sum);
|
||||
const float var = var_sum / elems_per_row;
|
||||
const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
|
||||
|
||||
T* block_output = output + block_offset;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL; i++) {
|
||||
T* iteration_buffer = local_buffer + (i * T_per_load);
|
||||
const int iter_idx = i * stride + thread_offset;
|
||||
const bool do_loads = (iter_idx < elems_per_row);
|
||||
|
||||
T gamma_local[T_per_load];
|
||||
|
||||
mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
iteration_buffer[j] *= denom;
|
||||
iteration_buffer[j] *= gamma_local[j];
|
||||
}
|
||||
|
||||
if (do_loads) {
|
||||
mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
|
||||
__global__ void pre_rms_norm(T* output,
|
||||
T* res_out,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
float epsilon,
|
||||
int elems_per_row)
|
||||
{
|
||||
constexpr int T_per_load = rms::granularity / sizeof(T);
|
||||
|
||||
cg::thread_block tb = cg::this_thread_block();
|
||||
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
||||
|
||||
// X-dimension of the block
|
||||
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
|
||||
(tb.thread_index().y * elems_per_row);
|
||||
const int thread_offset = tb.thread_index().x * T_per_load;
|
||||
const int base_offset = block_offset + thread_offset;
|
||||
const int stride = blockDim.x * T_per_load;
|
||||
|
||||
float var_sum = reduce::init<rop::Add, float>();
|
||||
|
||||
const T* input_base = vals + base_offset;
|
||||
const T* residual_base = residual + base_offset;
|
||||
T* res_output = res_out + base_offset;
|
||||
|
||||
T local_buffer[UNROLL * T_per_load];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL; i++) {
|
||||
T* iteration_buffer = local_buffer + (i * T_per_load);
|
||||
T residual_buffer[T_per_load];
|
||||
|
||||
const int iter_offset = i * stride + thread_offset;
|
||||
const bool do_loads = (iter_offset < elems_per_row);
|
||||
|
||||
mem_access::load_global<rms::granularity>(
|
||||
iteration_buffer, input_base + (i * stride), do_loads);
|
||||
mem_access::load_global<rms::granularity>(
|
||||
residual_buffer, residual_base + (i * stride), do_loads);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
iteration_buffer[j] += residual_buffer[j];
|
||||
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
|
||||
|
||||
var_sum = reduce::element<rop::Add, float>(var_sum, vals_up_cast * vals_up_cast);
|
||||
}
|
||||
|
||||
if (do_loads) {
|
||||
mem_access::store_global<rms::granularity>(res_output + i * stride, iteration_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, var_sum);
|
||||
const float var = var_sum / elems_per_row;
|
||||
const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
|
||||
|
||||
T* block_output = output + block_offset;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL; i++) {
|
||||
T* iteration_buffer = local_buffer + (i * T_per_load);
|
||||
const int iter_idx = i * stride + thread_offset;
|
||||
const bool do_loads = (iter_idx < elems_per_row);
|
||||
|
||||
T gamma_local[T_per_load];
|
||||
|
||||
mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < T_per_load; j++) {
|
||||
iteration_buffer[j] *= denom;
|
||||
iteration_buffer[j] *= gamma_local[j];
|
||||
}
|
||||
|
||||
if (do_loads) {
|
||||
mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
|
||||
rms_norm<T, UNROLL, threadsPerGroup, maxThreads> \
|
||||
<<<grid, block, 0, stream>>>(norm_output, vals, gamma, epsilon, elems_per_row);
|
||||
|
||||
#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
|
||||
pre_rms_norm<T, UNROLL, threadsPerGroup, maxThreads><<<grid, block, 0, stream>>>( \
|
||||
norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row);
|
||||
|
||||
#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
|
||||
if (pre_norm) { \
|
||||
LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
|
||||
} else { \
|
||||
LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_rms_norm(T* norm_output,
|
||||
T* res_output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// 8 for __half, 4 for float
|
||||
constexpr int T_per_load = rms::granularity / sizeof(T);
|
||||
constexpr int maxThreads = 256;
|
||||
constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2;
|
||||
|
||||
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
|
||||
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll;
|
||||
|
||||
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
|
||||
// warp-sized blocks rather than stepping up to 64/96 threads
|
||||
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
|
||||
const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
|
||||
|
||||
const int groups_per_block_max =
|
||||
is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1;
|
||||
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
|
||||
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
|
||||
|
||||
dim3 block(threads_per_group, groups_per_block);
|
||||
dim3 grid(groups_launch);
|
||||
|
||||
const int elems_per_step = threads_per_group * h_per_step;
|
||||
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
|
||||
|
||||
bool pre_norm = (residual == nullptr) ? false : true;
|
||||
|
||||
if (is_subblock_schedule) {
|
||||
// <=128
|
||||
if (threads_per_group == 1) {
|
||||
LAUNCH_ALL_RMS_NORM(1, 1, maxThreads);
|
||||
} else if (threads_per_group == 2) {
|
||||
LAUNCH_ALL_RMS_NORM(1, 2, maxThreads);
|
||||
} else if (threads_per_group == 4) {
|
||||
LAUNCH_ALL_RMS_NORM(1, 4, maxThreads);
|
||||
} else if (threads_per_group == 8) {
|
||||
LAUNCH_ALL_RMS_NORM(1, 8, maxThreads);
|
||||
} else if (threads_per_group == 16) {
|
||||
LAUNCH_ALL_RMS_NORM(1, 16, maxThreads);
|
||||
}
|
||||
} else if (external_unRoll == 1) {
|
||||
// 129 - 4096 elems
|
||||
// (this can launch with 1-7 warps as well)
|
||||
LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 2) {
|
||||
// 4097 - 8192 elems
|
||||
LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 3) {
|
||||
// 8193 - 12288 elems
|
||||
LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads);
|
||||
} else if (external_unRoll == 4) {
|
||||
// 12289 - 16384 elems
|
||||
LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_LAUNCH_RMS_NORM(T) \
|
||||
template void launch_rms_norm<T>(T * norm_output, \
|
||||
T * res_output, \
|
||||
const T* vals, \
|
||||
const T* residual, \
|
||||
const T* gamma, \
|
||||
float epsilon, \
|
||||
int rows, \
|
||||
int elems_per_row, \
|
||||
cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_LAUNCH_RMS_NORM(float)
|
||||
INSTANTIATE_LAUNCH_RMS_NORM(__half)
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16)
|
||||
#endif
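For reference, the math in rms_norm / pre_rms_norm above matches the usual RMSNorm definition; a hedged PyTorch sketch of the expected semantics (upcasting to fp32 for the reduction, as the kernel does):

import torch

def rms_norm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # y = x * rsqrt(mean(x^2) + eps) * gamma, reduced rowwise in fp32.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * gamma

def pre_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5):
    # Pre-norm variant: add the residual, store the sum, then normalize it.
    res_out = x + residual
    return rms_norm_ref(res_out, gamma, eps), res_out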
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "ds_kernel_utils.h"
|
||||
|
||||
template <typename T>
|
||||
void launch_rms_norm(T* norm_output,
|
||||
T* res_output,
|
||||
const T* vals,
|
||||
const T* residual,
|
||||
const T* gamma,
|
||||
float epsilon,
|
||||
int rows,
|
||||
int elems_per_row,
|
||||
cudaStream_t stream);
|
||||
|
||||
void rms_norm(torch::Tensor& norm_output,
|
||||
torch::Tensor& norm_input,
|
||||
torch::Tensor& gamma,
|
||||
float epsilon);
|
||||
|
||||
void rms_pre_norm(torch::Tensor& norm_output,
|
||||
torch::Tensor& residual_output,
|
||||
torch::Tensor& norm_input,
|
||||
torch::Tensor& residual_input,
|
||||
torch::Tensor& gamma,
|
||||
float epsilon);
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from .rms_norm_base import CUDARMSNormBase
|
||||
|
||||
|
||||
class CUDARMSNorm(CUDARMSNormBase):
|
||||
"""
|
||||
Floating point RMS norm kernel for CUDA/ROCm.

Performs: z = rms_norm(x)
|
||||
"""
|
||||
|
||||
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
output_z may alias input_x directly. All Tensors should have the same shape.
|
||||
|
||||
Parameters:
|
||||
output_z (torch.Tensor): Output tensor.
|
||||
input_x (torch.Tensor): Input tensor.
|
||||
gamma (torch.Tensor): Gamma tensor.
|
||||
"""
|
||||
self.inf_module.rms_norm(output_z, input_x, gamma, self.epsilon)
|
||||
return output_z
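A hypothetical usage sketch of CUDARMSNorm as defined above; shapes, dtype, and the presence of a CUDA device are illustrative assumptions.

# Hypothetical usage of CUDARMSNorm (defined above); shapes/dtype are illustrative.
import torch

channels = 4096
rms_norm = CUDARMSNorm(channels, torch.float16)   # epsilon defaults to 1e-5

x = torch.randn(8, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
z = torch.empty_like(x)

rms_norm(z, x, gamma)   # z now holds the RMS-normalized rows scaled by gamma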
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import elem_size
|
||||
from deepspeed.ops.op_builder import InferenceCoreBuilder
|
||||
|
||||
|
||||
class CUDARMSNormBase(DSKernelBase):
|
||||
"""
|
||||
Base class for CUDA RMS norm kernels. They all share the same validation logic,
so we can share it here.
|
||||
"""
|
||||
|
||||
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
|
||||
"""
|
||||
Parameters:
|
||||
channels (int): Number of channels in the input tensor. The row size in bytes
(elem_size(fp_dtype) * channels) must be a multiple of 16.
|
||||
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
|
||||
are torch.float16, torch.bfloat16, and torch.float32.
|
||||
"""
|
||||
if fp_dtype not in CUDARMSNormBase.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, CUDARMSNormBase.supported_dtypes))
|
||||
|
||||
if elem_size(fp_dtype) * channels % 16 != 0:
|
||||
raise ValueError("channels must be divisible by 16 bytes")
|
||||
|
||||
self.inf_module = InferenceCoreBuilder().load()
|
||||
self.epsilon = epsilon
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .rms_norm_base import CUDARMSNormBase
|
||||
|
||||
|
||||
class CUDARMSPreNorm(CUDARMSNormBase):
|
||||
"""
|
||||
Floating point pre-RMS norm kernel for CUDA/ROCm.

Performs: z_res = x_res + y_hid
          z_hid = rms_norm(z_res)
|
||||
"""
|
||||
|
||||
def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
|
||||
gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
z_res can alias x_res. All non-parameter input/output tensors
|
||||
must have the same shape. z_hid can alias y_hid.
|
||||
|
||||
Parameters:
|
||||
z_res (torch.Tensor): Output residual.
|
||||
z_hid (torch.Tensor): Output hidden states.
|
||||
x_res (torch.Tensor): Input residual.
|
||||
y_hid (torch.Tensor): Input hidden states.
|
||||
gamma (torch.Tensor): Gamma tensor.
|
||||
Returns:
z_res, z_hid (Tuple[torch.Tensor, torch.Tensor]): The output residual and normalized hidden states.
|
||||
"""
|
||||
self.inf_module.rms_pre_norm(z_hid, z_res, y_hid, x_res, gamma, self.epsilon)
|
||||
return z_res, z_hid
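A hypothetical usage sketch of CUDARMSPreNorm as defined above; shapes and dtype are illustrative assumptions.

# Hypothetical usage of CUDARMSPreNorm (defined above); shapes/dtype are illustrative.
import torch

channels = 4096
pre_norm = CUDARMSPreNorm(channels, torch.float16)

residual = torch.randn(8, channels, dtype=torch.float16, device="cuda")
hidden = torch.randn_like(residual)
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")

res_out = torch.empty_like(residual)
hid_out = torch.empty_like(hidden)

# res_out = residual + hidden; hid_out = rms_norm(res_out) * gamma
pre_norm(res_out, hid_out, residual, hidden, gamma)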
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .gated_activation import *
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import ActivationType, elem_size
|
||||
from deepspeed.ops.op_builder import InferenceCoreBuilder
|
||||
|
||||
|
||||
class CUDAGatedActivation(DSKernelBase):
|
||||
"""
|
||||
CUDA implementation of gated activation kernel. This kernel assumes that the input
|
||||
tensor has gate and activation values in adjacent channels. The output tensor should
|
||||
have half the dimensionality of the input tensor.
|
||||
"""
|
||||
|
||||
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
supported_act_fns = [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
|
||||
|
||||
def __init__(self, channels: int, fp_dtype: torch.dtype, act_fn: ActivationType) -> None:
|
||||
"""
|
||||
Compile and validate the gated activation kernel.
|
||||
|
||||
Args:
|
||||
channels (int): Number of columns in the output tensor. The row size in bytes
(elem_size(fp_dtype) * channels) must be a multiple of 8.
|
||||
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
|
||||
are torch.float16, torch.bfloat16, and torch.float32.
|
||||
act_fn (ActivationType): Activation function to use. Supported values are GEGLU, ReGLU, and SiGLU.
|
||||
"""
|
||||
if fp_dtype not in CUDAGatedActivation.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, CUDAGatedActivation.supported_dtypes))
|
||||
|
||||
act_fn = ActivationType(act_fn)
|
||||
if act_fn not in CUDAGatedActivation.supported_act_fns:
|
||||
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
|
||||
act_fn, CUDAGatedActivation.supported_act_fns))
|
||||
|
||||
if elem_size(fp_dtype) * channels % 8 != 0:
|
||||
raise ValueError("Channels must be divisible by 16 bytes")
|
||||
|
||||
if elem_size(fp_dtype) * channels > 98304:
|
||||
raise ValueError(
|
||||
"Kernel only compiled to support 98304 bytes per row, please file an issue if your model requires more."
|
||||
)
|
||||
|
||||
self.inf_module = InferenceCoreBuilder().load()
|
||||
self.act_fn = act_fn
|
||||
self.kernel = self.inf_module.gated_activation
|
||||
|
||||
def __call__(self, output: torch.Tensor, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Performs gated activation on the input tensor, writing the result to the output tensor.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor): Output tensor. Can be of shape [T, C // 2] or [B, S, C // 2]
input (torch.Tensor): Input tensor. Can be of shape [T, C] or [B, S, C]
bias (torch.Tensor, optional): Bias tensor of shape [C], added to the input before the gated activation.
|
||||
"""
|
||||
self.kernel(output, input, bias, self.act_fn.value)
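A hypothetical usage sketch of CUDAGatedActivation as defined above; following the constructor docstring, channels is taken to be the output width, and all shapes are illustrative assumptions.

# Hypothetical usage of CUDAGatedActivation (defined above); shapes are illustrative.
import torch

out_channels = 4096
geglu = CUDAGatedActivation(out_channels, torch.float16, ActivationType.GEGLU)

x = torch.randn(16, 2 * out_channels, dtype=torch.float16, device="cuda")  # gate/act interleaved
y = torch.empty(16, out_channels, dtype=torch.float16, device="cuda")

geglu(y, x)   # y[t, c] = x[t, 2c + 1] * gelu(x[t, 2c]) with tanh-approximate GELU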
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "gated_activation_kernels.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using scalar_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kBFloat16) { \
|
||||
using scalar_t = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
|
||||
} \
|
||||
}()
|
||||
#else
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using scalar_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using scalar_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
void ds_gated_activation(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
c10::optional<torch::Tensor>& bias,
|
||||
int activation_type_raw)
|
||||
{
|
||||
bool ragged_input = input.dim() == 2;
|
||||
|
||||
const ActivationType activation_type = static_cast<ActivationType>(activation_type_raw);
|
||||
|
||||
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
|
||||
const int cols = ragged_input ? input.size(1) : input.size(2);
|
||||
|
||||
DISPATCH_FOR_FLOAT(input.scalar_type(), [&] {
|
||||
scalar_t* bias_ptr = nullptr;
|
||||
if (bias.has_value()) {
|
||||
TORCH_CHECK(bias.value().scalar_type() == input.scalar_type(),
|
||||
"Bias type must match input type");
|
||||
TORCH_CHECK(bias.value().numel() == cols,
|
||||
"Bias must have the same number of elements as the input channels");
|
||||
bias_ptr = reinterpret_cast<scalar_t*>(bias.value().data_ptr());
|
||||
}
|
||||
|
||||
scalar_t* output_ptr = reinterpret_cast<scalar_t*>(output.data_ptr());
|
||||
const scalar_t* input_ptr = reinterpret_cast<const scalar_t*>(input.data_ptr());
|
||||
|
||||
launch_gated_activation(output_ptr,
|
||||
input_ptr,
|
||||
bias_ptr,
|
||||
rows,
|
||||
cols,
|
||||
activation_type,
|
||||
c10::cuda::getCurrentCUDAStream());
|
||||
});
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <stdexcept>
|
||||
#include "activation_type.h"
|
||||
#include "conversion_utils.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "memory_access_utils.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace gated_act {
|
||||
|
||||
constexpr int access_size = 16;
|
||||
constexpr int threads = 1024;
|
||||
|
||||
template <ActivationType ActType>
|
||||
float gated_act_fn(float x, float y);
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float gated_act_fn<ActivationType::GEGLU>(float x, float y)
|
||||
{
|
||||
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
|
||||
constexpr float mul_param = 0.044715;
|
||||
return y * x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float gated_act_fn<ActivationType::ReGLU>(float x, float y)
|
||||
{
|
||||
return y * (x > 0.0f ? x : 0.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
DS_D_INLINE float gated_act_fn<ActivationType::SiGLU>(float x, float y)
|
||||
{
|
||||
return y * (x / (1.0f + expf(-x)));
|
||||
}
|
||||
|
||||
} // namespace gated_act
|
||||
|
||||
template <typename T, ActivationType ActType, int loopUnroll>
|
||||
__global__ void gated_activation_kernel(T* output,
|
||||
const T* input,
|
||||
const T* bias,
|
||||
int rows,
|
||||
int cols)
|
||||
{
|
||||
constexpr int read_vector = gated_act::access_size / sizeof(T);
|
||||
constexpr int write_vector = read_vector / 2;
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x * read_vector;
|
||||
|
||||
const T* input_row = input + row * cols;
|
||||
T* output_row = output + row * cols / 2;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < loopUnroll; i++) {
|
||||
T read[read_vector];
|
||||
T bias_read[read_vector];
|
||||
T store[write_vector];
|
||||
|
||||
const int read_offset = col + gated_act::threads * read_vector * i;
|
||||
const int write_offset = col / 2 + gated_act::threads * write_vector * i;
|
||||
|
||||
if (i != loopUnroll - 1 || read_offset < cols) {
|
||||
mem_access::load_global<gated_act::access_size>(read, input_row + read_offset);
|
||||
mem_access::load_global<gated_act::access_size>(
|
||||
bias_read, bias + read_offset, bias != nullptr);
|
||||
|
||||
for (int j = 0; j < write_vector; j++) {
|
||||
float g_val =
|
||||
conversion::to<float>(read[j * 2]) + conversion::to<float>(bias_read[j * 2]);
|
||||
float a_val = conversion::to<float>(read[j * 2 + 1]) +
|
||||
conversion::to<float>(bias_read[j * 2 + 1]);
|
||||
|
||||
float act_val = gated_act::gated_act_fn<ActType>(g_val, a_val);
|
||||
store[j] = conversion::to<T>(act_val);
|
||||
}
|
||||
|
||||
mem_access::store_global<gated_act::access_size / 2>(output_row + write_offset, store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define DISPATCH_UNROLL(unroll_val) \
|
||||
gated_activation_kernel<T, ActType, unroll_val> \
|
||||
<<<grid, block, 0, stream>>>(output, input, bias, rows, cols);
|
||||
|
||||
template <typename T, ActivationType ActType>
|
||||
void launch_gated_activation_impl(T* output,
|
||||
const T* input,
|
||||
const T* bias,
|
||||
int rows,
|
||||
int cols,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
constexpr int read_vector = gated_act::access_size / sizeof(T);
|
||||
constexpr int cols_per_unroll = gated_act::threads * read_vector;
|
||||
const int req_threads = (cols + read_vector - 1) / read_vector;
|
||||
const int threads = std::min(req_threads, gated_act::threads);
|
||||
|
||||
const dim3 grid(rows);
|
||||
const dim3 block(threads);
|
||||
const int unroll = (cols + cols_per_unroll - 1) / cols_per_unroll;
|
||||
|
||||
if (unroll == 1) {
|
||||
DISPATCH_UNROLL(1);
|
||||
} else if (unroll == 2) {
|
||||
DISPATCH_UNROLL(2);
|
||||
} else if (unroll == 3) {
|
||||
DISPATCH_UNROLL(3);
|
||||
} else if (unroll == 4) {
|
||||
DISPATCH_UNROLL(4);
|
||||
} else if (unroll == 5) {
|
||||
DISPATCH_UNROLL(5);
|
||||
} else if (unroll == 6) {
|
||||
DISPATCH_UNROLL(6);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Called with more columns than supported, please report this bug and this limit will "
|
||||
"be increased.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_gated_activation(T* output,
|
||||
const T* input,
|
||||
const T* bias,
|
||||
int rows,
|
||||
int cols,
|
||||
ActivationType act_type,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
switch (act_type) {
|
||||
case ActivationType::GEGLU:
|
||||
launch_gated_activation_impl<T, ActivationType::GEGLU>(
|
||||
output, input, bias, rows, cols, stream);
|
||||
break;
|
||||
case ActivationType::ReGLU:
|
||||
launch_gated_activation_impl<T, ActivationType::ReGLU>(
|
||||
output, input, bias, rows, cols, stream);
|
||||
break;
|
||||
case ActivationType::SiGLU:
|
||||
launch_gated_activation_impl<T, ActivationType::SiGLU>(
|
||||
output, input, bias, rows, cols, stream);
|
||||
break;
|
||||
default: throw std::runtime_error("Unsupported activation type");
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_TYPE(T) \
|
||||
template void launch_gated_activation<T>(T * output, \
|
||||
const T* input, \
|
||||
const T* bias, \
|
||||
int rows, \
|
||||
int cols, \
|
||||
ActivationType act_type, \
|
||||
cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_FOR_TYPE(float)
|
||||
INSTANTIATE_FOR_TYPE(__half)
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_FOR_TYPE(__nv_bfloat16)
|
||||
#endif
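For reference, a PyTorch sketch (illustrative only, not part of the kernel sources) of the math the kernel above computes on interleaved gate/activation channels, using the same tanh-approximate GELU constants.

# Illustrative PyTorch reference for the gated activation math above.
import math
from typing import Optional

import torch

def gated_activation_ref(x: torch.Tensor, bias: Optional[torch.Tensor], act: str) -> torch.Tensor:
    if bias is not None:
        x = x + bias
    gate = x[..., 0::2].float()     # even channels pass through the activation
    act_in = x[..., 1::2].float()   # odd channels gate the result
    if act == "geglu":              # tanh-approximate GELU, matching the constants above
        c = math.sqrt(2.0 / math.pi)
        g = 0.5 * gate * (1.0 + torch.tanh(c * (gate + 0.044715 * gate ** 3)))
    elif act == "reglu":
        g = torch.relu(gate)
    elif act == "siglu":
        g = gate * torch.sigmoid(gate)
    else:
        raise ValueError(f"Unsupported activation: {act}")
    return (act_in * g).to(x.dtype)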
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "activation_type.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
|
||||
template <typename T>
|
||||
void launch_gated_activation(T* output,
|
||||
const T* vals,
|
||||
const T* bias,
|
||||
int rows,
|
||||
int cols,
|
||||
ActivationType activation_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
void ds_gated_activation(at::Tensor& output,
|
||||
at::Tensor& input,
|
||||
c10::optional<torch::Tensor>& bias,
|
||||
int activation_type_raw);
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .mixed_gemm import *
|
||||
from .moe_gemm import *
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "mixed_gemm.h"
|
||||
#include "moe_gemm.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
// mixed_gemm.h
|
||||
m.def("mixed_gemm", &mixed_gemm, "Mixed-precision GEMM");
|
||||
|
||||
// moe_gemm.h
|
||||
m.def("moe_gemm", &moe_gemm, "MultiGEMM for MoE (16-bit weights)");
|
||||
m.def("mixed_moe_gemm", &mixed_moe_gemm, "MultiGEMM for MoE (4-bit/8-bit weights)");
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .mixed_gemm import *
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include "mixed_gemm.h"
|
||||
#include "mixed_gemm_api.h"
|
||||
#include "weight_variant.h"
|
||||
|
||||
// Switch helpers inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
#define ACT_DTYPE_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
using ActivationDtype = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
using ActivationDtype = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define WEIGHT_VARIANT_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kFP8; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kFP4; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void mixed_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& scales,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
int num_bits,
|
||||
int activation_raw)
|
||||
{
|
||||
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
|
||||
"Output and hidden states must have the same dtype");
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "Data width must be 4 or 8");
|
||||
TORCH_CHECK(output.size(0) == hidden_states.size(0), "Token dimension mismatch");
|
||||
|
||||
int32_t m = output.size(0);
|
||||
int32_t k = hidden_states.size(1);
|
||||
int32_t n = weight.size(1);
|
||||
|
||||
TORCH_CHECK(weight.size(0) == k, "Weight dimension mismatch");
|
||||
|
||||
ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
|
||||
WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] {
|
||||
fastertransformer::CutlassFpAIntBGemmRunner<ActivationDtype, WVariant> runner =
|
||||
*MixedGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
|
||||
|
||||
ActivationType activation_type = (ActivationType)activation_raw;
|
||||
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
|
||||
runner.gemm((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(const char*)weight.data_ptr(),
|
||||
(ActivationDtype*)scales.data_ptr(),
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
nullptr,
|
||||
0,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
} else {
|
||||
ActivationDtype* bias_ptr = nullptr;
|
||||
if (bias.has_value()) { bias_ptr = (ActivationDtype*)bias.value().data_ptr(); }
|
||||
runner.gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(char*)weight.data_ptr(),
|
||||
(ActivationDtype*)scales.data_ptr(),
|
||||
bias_ptr,
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
activation_type,
|
||||
nullptr,
|
||||
0,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void mixed_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& scales,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
int num_bits,
|
||||
int activation_raw);
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import ActivationType, DtypeEnum
|
||||
from deepspeed.ops.op_builder import InferenceCutlassBuilder
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MixedGEMM(DSKernelBase):
|
||||
"""
|
||||
CUTLASS implementation of mixed-precision GEMM (16-bit activations with 4-bit or 8-bit quantized weights).
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
|
||||
|
||||
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None:
|
||||
|
||||
if not isinstance(fp_dtype, DtypeEnum):
|
||||
fp_dtype = DtypeEnum(fp_dtype)
|
||||
|
||||
if fp_dtype not in MixedGEMM.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, MixedGEMM.supported_dtypes))
|
||||
|
||||
if act_fn not in MixedGEMM.supported_act_fns:
|
||||
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
|
||||
act_fn, MixedGEMM.supported_act_fns))
|
||||
|
||||
if num_bits != 4 and num_bits != 8:
|
||||
raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits))
|
||||
|
||||
inf_module = InferenceCutlassBuilder().load()
|
||||
self.num_bits = num_bits
|
||||
self.kernel = inf_module.mixed_gemm
|
||||
self.act_fn = act_fn
|
||||
|
||||
def __call__(self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
biases: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Performs a mixed-precision GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
|
||||
|
||||
Arguments:
|
||||
output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
|
||||
hidden_states (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
|
||||
weights (torch.Tensor): The weights of shape [in_neurons, out_neurons]. These weights must be contiguous.
|
||||
scales (torch.Tensor): The scales of shape [out_neurons]. These scales must be contiguous.
|
||||
biases (torch.Tensor): The biases of shape [out_neurons]. These biases must be contiguous.
|
||||
|
||||
Returns:
|
||||
output
|
||||
"""
|
||||
self.kernel(output, hidden_states, weights, scales, biases, self.num_bits, self.act_fn)
|
||||
return output
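A hypothetical usage sketch of MixedGEMM as defined above; the int8 storage of the quantized weights and all shapes are illustrative assumptions, and the authoritative weight/scale layout comes from the quantizer that produces them.

# Hypothetical usage of MixedGEMM (defined above); weight layout and shapes are assumptions.
import torch

n_tokens, in_neurons, out_neurons = 16, 4096, 11008
gemm = MixedGEMM(DtypeEnum.fp16, ActivationType.IDENTITY, num_bits=8)

hidden = torch.randn(n_tokens, in_neurons, dtype=torch.float16, device="cuda")
weights = torch.empty(in_neurons, out_neurons, dtype=torch.int8, device="cuda")   # pre-quantized
scales = torch.ones(out_neurons, dtype=torch.float16, device="cuda")
output = torch.empty(n_tokens, out_neurons, dtype=torch.float16, device="cuda")

gemm(output, hidden, weights, scales)   # bias omitted; activation fused per act_fn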
@ -0,0 +1,57 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "activation_type.h"
|
||||
#include "weight_variant.h"
|
||||
|
||||
namespace fastertransformer {
|
||||
|
||||
template <typename T, WeightVariant V>
|
||||
class CutlassFpAIntBGemmRunner {
|
||||
public:
|
||||
void gemm(const T* A,
|
||||
const char* B,
|
||||
const T* weight_scales,
|
||||
T* C,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
char* workspace_ptr,
|
||||
const size_t workspace_bytes,
|
||||
cudaStream_t stream);
|
||||
|
||||
void gemm_bias_act(const T* A,
|
||||
const char* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
ActivationType activation_type,
|
||||
char* workspace_ptr,
|
||||
const size_t workspace_bytes,
|
||||
cudaStream_t stream);
|
||||
};
|
||||
|
||||
} // namespace fastertransformer
|
||||
|
||||
template <typename T, WeightVariant V>
|
||||
class MixedGemmContext {
|
||||
public:
|
||||
MixedGemmContext() { _runner = new fastertransformer::CutlassFpAIntBGemmRunner<T, V>(); }
|
||||
|
||||
virtual ~MixedGemmContext() { delete _runner; }
|
||||
|
||||
static MixedGemmContext& Instance()
|
||||
{
|
||||
static MixedGemmContext _ctx;
|
||||
return _ctx;
|
||||
}
|
||||
|
||||
fastertransformer::CutlassFpAIntBGemmRunner<T, V>* GeMM_Runner() const { return _runner; }
|
||||
|
||||
fastertransformer::CutlassFpAIntBGemmRunner<T, V>* _runner;
|
||||
};
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .mixed_moe_gemm import *
|
||||
from .moe_gemm import *
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import ActivationType, DtypeEnum
|
||||
from deepspeed.ops.op_builder import InferenceCutlassBuilder
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MixedMoEGEMM(DSKernelBase):
|
||||
"""
|
||||
CUTLASS implementation of mixed-precision MoE GEMM (16-bit activations with 4-bit or 8-bit quantized expert weights).
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
|
||||
|
||||
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None:
|
||||
|
||||
if not isinstance(fp_dtype, DtypeEnum):
|
||||
fp_dtype = DtypeEnum(fp_dtype)
|
||||
|
||||
if fp_dtype not in MixedMoEGEMM.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, MixedMoEGEMM.supported_dtypes))
|
||||
|
||||
if act_fn not in MixedMoEGEMM.supported_act_fns:
|
||||
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
|
||||
act_fn, MixedMoEGEMM.supported_act_fns))
|
||||
|
||||
if num_bits != 4 and num_bits != 8:
|
||||
raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits))
|
||||
|
||||
inf_module = InferenceCutlassBuilder().load()
|
||||
self.num_bits = num_bits
|
||||
self.kernel = inf_module.mixed_moe_gemm
|
||||
self.act_fn = act_fn
|
||||
|
||||
def __call__(self,
|
||||
ordered_output: torch.Tensor,
|
||||
ordered_input: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
total_rows_before_expert: torch.Tensor,
|
||||
biases: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Performs a mixed-precision MoE GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
|
||||
|
||||
Arguments:
|
||||
ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
|
||||
ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
|
||||
weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous.
|
||||
scales (torch.Tensor): The scales of shape [n_experts, out_neurons]. These scales must be contiguous.
|
||||
total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts].
|
||||
biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous.
|
||||
|
||||
Returns:
|
||||
ordered_output
|
||||
"""
|
||||
self.kernel(ordered_output, ordered_input, weights, scales, biases, total_rows_before_expert, self.num_bits,
|
||||
self.act_fn)
|
||||
return ordered_output
|
|
@ -0,0 +1,175 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include "moe_gemm.h"
|
||||
#include "moe_gemm_api.h"
|
||||
#include "weight_variant.h"
|
||||
|
||||
// Switch helpers inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
#define HIDDEN_DTYPE_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
using ActivationDtype = __half; \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kFP16; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
using ActivationDtype = __nv_bfloat16; \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kBF16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void moe_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
at::Tensor& total_rows_before_expert,
|
||||
int activation_raw)
|
||||
{
|
||||
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
|
||||
"Output and hidden states must have the same dtype");
|
||||
TORCH_CHECK(output.dtype() == weight.dtype(), "Output and weight must have the same dtype");
|
||||
|
||||
int64_t total_rows = hidden_states.size(0);
|
||||
int64_t gemm_k = hidden_states.size(1);
|
||||
int64_t gemm_n = weight.size(2);
|
||||
int num_experts = weight.size(0);
|
||||
|
||||
TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch");
|
||||
TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch");
|
||||
TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch");
|
||||
TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch");
|
||||
|
||||
HIDDEN_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
|
||||
fastertransformer::MoeGemmRunner<ActivationDtype, WVariant> runner =
|
||||
*MoeGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
|
||||
|
||||
ActivationType activation_type = (ActivationType)activation_raw;
|
||||
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
|
||||
runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(char*)weight.data_ptr(),
|
||||
nullptr,
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
(int64_t*)total_rows_before_expert.data_ptr(),
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
} else {
|
||||
ActivationDtype* bias_ptr = nullptr;
|
||||
if (bias.has_value()) {
|
||||
bias_ptr = (ActivationDtype*)bias.value().data_ptr();
|
||||
TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch");
|
||||
TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch");
|
||||
}
|
||||
runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(char*)weight.data_ptr(),
|
||||
nullptr,
|
||||
bias_ptr,
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
(int64_t*)total_rows_before_expert.data_ptr(),
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
activation_type,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define ACT_DTYPE_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
using ActivationDtype = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
using ActivationDtype = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define WEIGHT_VARIANT_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kFP8; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr WeightVariant WVariant = WeightVariant::kFP4; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void mixed_moe_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& scales,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
at::Tensor& total_rows_before_expert,
|
||||
int num_bits,
|
||||
int activation_raw)
|
||||
{
|
||||
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
|
||||
"Output and hidden states must have the same dtype");
|
||||
|
||||
int64_t total_rows = hidden_states.size(0);
|
||||
int64_t gemm_k = hidden_states.size(1);
|
||||
int64_t gemm_n = weight.size(2);
|
||||
int num_experts = weight.size(0);
|
||||
|
||||
TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch");
|
||||
TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch");
|
||||
TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch");
|
||||
TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch");
|
||||
|
||||
ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
|
||||
WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] {
|
||||
fastertransformer::MoeGemmRunner<ActivationDtype, WVariant> runner =
|
||||
*MoeGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
|
||||
|
||||
ActivationType activation_type = (ActivationType)activation_raw;
|
||||
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
|
||||
runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(char*)weight.data_ptr(),
|
||||
(ActivationDtype*)scales.data_ptr(),
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
(int64_t*)total_rows_before_expert.data_ptr(),
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
} else {
|
||||
ActivationDtype* bias_ptr = nullptr;
|
||||
if (bias.has_value()) {
|
||||
bias_ptr = (ActivationDtype*)bias.value().data_ptr();
|
||||
TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch");
|
||||
TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch");
|
||||
}
|
||||
runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
|
||||
(char*)weight.data_ptr(),
|
||||
(ActivationDtype*)scales.data_ptr(),
|
||||
bias_ptr,
|
||||
(ActivationDtype*)output.data_ptr(),
|
||||
(int64_t*)total_rows_before_expert.data_ptr(),
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
activation_type,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void moe_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
at::Tensor& total_rows_before_expert,
|
||||
int activation_raw);
|
||||
|
||||
void mixed_moe_gemm(at::Tensor& output,
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& scales,
|
||||
c10::optional<at::Tensor>& bias,
|
||||
at::Tensor& total_rows_before_expert,
|
||||
int num_bits,
|
||||
int activation_raw);
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from ....inference_utils import ActivationType, DtypeEnum
|
||||
from deepspeed.ops.op_builder import InferenceCutlassBuilder
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MoEGEMM(DSKernelBase):
|
||||
"""
|
||||
CUTLASS implementation of MoE GEMM.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
|
||||
|
||||
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType) -> None:
|
||||
|
||||
if not isinstance(fp_dtype, DtypeEnum):
|
||||
fp_dtype = DtypeEnum(fp_dtype)
|
||||
|
||||
if fp_dtype not in MoEGEMM.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
fp_dtype, MoEGEMM.supported_dtypes))
|
||||
|
||||
if act_fn not in MoEGEMM.supported_act_fns:
|
||||
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
|
||||
act_fn, MoEGEMM.supported_act_fns))
|
||||
|
||||
inf_module = InferenceCutlassBuilder().load()
|
||||
self.kernel = inf_module.moe_gemm
|
||||
self.act_fn = act_fn
|
||||
|
||||
def __call__(self,
|
||||
ordered_output: torch.Tensor,
|
||||
ordered_input: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
total_rows_before_expert: torch.Tensor,
|
||||
biases: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Performs a MoE GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
|
||||
|
||||
Arguments:
|
||||
ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
|
||||
ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
|
||||
weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous.
|
||||
total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts].
|
||||
biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous.
|
||||
|
||||
Returns:
|
||||
ordered_output
|
||||
"""
|
||||
self.kernel(ordered_output, ordered_input, weights, biases, total_rows_before_expert, self.act_fn)
|
||||
return ordered_output
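A hypothetical usage sketch of MoEGEMM as defined above; tokens are assumed to already be ordered by expert, total_rows_before_expert is taken here as the inclusive cumulative token count per expert, and all shapes are illustrative.

# Hypothetical usage of MoEGEMM (defined above); shapes and routing counts are illustrative.
import torch

n_experts, in_neurons, out_neurons = 4, 4096, 11008
moe_gemm = MoEGEMM(DtypeEnum.fp16, ActivationType.IDENTITY)

# 10 tokens total: 3 routed to expert 0, none to expert 1, 5 to expert 2, 2 to expert 3.
# Assumption: inclusive cumulative token counts per expert.
total_rows_before_expert = torch.tensor([3, 3, 8, 10], dtype=torch.int64, device="cuda")

ordered_input = torch.randn(10, in_neurons, dtype=torch.float16, device="cuda")
weights = torch.randn(n_experts, in_neurons, out_neurons, dtype=torch.float16, device="cuda")
ordered_output = torch.empty(10, out_neurons, dtype=torch.float16, device="cuda")

moe_gemm(ordered_output, ordered_input, weights, total_rows_before_expert)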
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "activation_type.h"
|
||||
#include "weight_variant.h"
|
||||
|
||||
namespace fastertransformer {
|
||||
|
||||
template <typename T, /*The type used for activations/scales/compute*/
|
||||
WeightVariant V /* The type for the MoE weights */>
|
||||
class MoeGemmRunner {
|
||||
public:
|
||||
MoeGemmRunner();
|
||||
|
||||
void moe_gemm_bias_act(const T* A,
|
||||
const char* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
ActivationType activation_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
void moe_gemm(const T* A,
|
||||
const char* B,
|
||||
const T* weight_scales,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
int multi_processor_count_;
|
||||
};
|
||||
|
||||
} // namespace fastertransformer
|
||||
|
||||
template <typename T, WeightVariant V>
|
||||
class MoeGemmContext {
|
||||
public:
|
||||
MoeGemmContext() { _runner = new fastertransformer::MoeGemmRunner<T, V>(); }
|
||||
|
||||
virtual ~MoeGemmContext() { delete _runner; }
|
||||
|
||||
static MoeGemmContext& Instance()
|
||||
{
|
||||
static MoeGemmContext _ctx;
|
||||
return _ctx;
|
||||
}
|
||||
|
||||
fastertransformer::MoeGemmRunner<T, V>* GeMM_Runner() const { return _runner; }
|
||||
|
||||
fastertransformer::MoeGemmRunner<T, V>* _runner;
|
||||
};
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
// Data structure that allows us to abstract internal CUTLASS datatypes/mappings
|
||||
// to the DeepSpeed-Kernels repo.
|
||||
|
||||
#pragma once
|
||||
|
||||
enum WeightVariant { kFP16, kBF16, kFP8, kFP4 };
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class DSKernelBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
If necessary, trigger compilation and warmup here. Autotuning of the kernel
should also happen at this stage to eliminate any potential hangs that might
occur mid-deployment. Validate that the desired run configuration is compatible.
|
||||
|
||||
It is not necessary to call super on this method.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
However the kernel needs to be called, it can be called here. Auto-tuning
|
||||
should never be performed here.
|
||||
|
||||
All inputs/outputs should be passed as arguments to this function. No allocations
|
||||
should be performed here.
|
||||
"""
|
||||
raise NotImplementedError()
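An illustrative, hypothetical DSKernelBase subclass that follows the contract above: validation and any compilation happen in __init__, and __call__ performs no allocations.

# Hypothetical DSKernelBase subclass, illustrating the contract above.
import torch

class CopyKernel(DSKernelBase):

    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    def __init__(self, fp_dtype: torch.dtype) -> None:
        if fp_dtype not in CopyKernel.supported_dtypes:
            raise ValueError(f"Unsupported dtype: {fp_dtype}")
        self.fp_dtype = fp_dtype  # JIT compilation / warmup would be triggered here

    def __call__(self, output: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
        # No allocations: the caller owns both buffers.
        output.copy_(input)
        return output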
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .atom_builder import *
|
||||
from .blocked_flash import *
|
||||
from .embed import *
|
||||
from .linear_blocked_kv_rotary import *
|
||||
from .logits_gather import *
|
||||
from .moe_gather import *
|
||||
from .moe_scatter import *
|
||||
from .top_1_gating import *
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .atom_builder import *
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "atom_builder.h"
|
||||
#include "attention_atom.h"
|
||||
#include "ragged_dtypes.h"
|
||||
|
||||
int32_t build_atoms(torch::Tensor& atoms_ten,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& kv_ptrs,
|
||||
const int32_t q_block_size,
|
||||
const int32_t kv_block_size)
|
||||
{
|
||||
const RaggedBatchDescriptor* batch_desc =
|
||||
reinterpret_cast<const RaggedBatchDescriptor*>(batch_metadata.data_ptr());
|
||||
|
||||
const InflightSeqDescriptor* seq_desc =
|
||||
reinterpret_cast<const InflightSeqDescriptor*>(seq_metadata.data_ptr());
|
||||
|
||||
int32_t** kv_ptr_list = reinterpret_cast<int32_t**>(kv_ptrs.data_ptr());
|
||||
|
||||
AttentionAtom* atoms = reinterpret_cast<AttentionAtom*>(atoms_ten.data_ptr());
|
||||
|
||||
int32_t n_atoms = 0;
|
||||
for (int i = 0; i < batch_desc->n_sequences; i++) {
|
||||
const int seq_atoms = (seq_desc[i].n_tokens + q_block_size - 1) / q_block_size;
|
||||
int32_t cur_start_idx = seq_desc[i].start_idx;
|
||||
int32_t global_start_idx = seq_desc[i].seen_tokens;
|
||||
int32_t remaining_toks = seq_desc[i].n_tokens;
|
||||
|
||||
for (int j = 0; j < seq_atoms; j++) {
|
||||
atoms[n_atoms].block_idx_list = kv_ptr_list[i];
|
||||
atoms[n_atoms].q_start_idx = cur_start_idx;
|
||||
atoms[n_atoms].q_len = std::min(remaining_toks, q_block_size);
|
||||
atoms[n_atoms].global_q_idx = global_start_idx;
|
||||
|
||||
const int32_t end_toks = global_start_idx + atoms[n_atoms].q_len;
|
||||
// TODO(cmikeh2): This logic needs to be changed for sparse implementations
|
||||
atoms[n_atoms].kv_blocks = (end_toks + kv_block_size - 1) / kv_block_size;
|
||||
atoms[n_atoms].total_extent = end_toks;
|
||||
|
||||
cur_start_idx += atoms[n_atoms].q_len;
|
||||
global_start_idx += atoms[n_atoms].q_len;
|
||||
remaining_toks -= atoms[n_atoms].q_len;
|
||||
n_atoms++;
|
||||
}
|
||||
}
|
||||
|
||||
return n_atoms;
|
||||
}
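A pure-Python sketch of the atom-construction loop above, for illustration only; seqs stands in for the InflightSeqDescriptor fields, and the real implementation also records each sequence's KV block pointer list.

# Pure-Python sketch of the atom construction above (illustrative only).
from dataclasses import dataclass
from typing import Iterable, List, Tuple

@dataclass
class Atom:
    q_start_idx: int    # first query token of this atom within the ragged batch
    q_len: int          # number of query tokens handled by this atom
    global_q_idx: int   # global position of the atom's first query token in its sequence
    kv_blocks: int      # number of KV blocks the atom must attend over
    total_extent: int   # total KV tokens visible to this atom

def build_atoms_ref(seqs: Iterable[Tuple[int, int, int]],
                    q_block_size: int, kv_block_size: int) -> List[Atom]:
    # seqs: (start_idx, n_tokens, seen_tokens) per in-flight sequence.
    atoms: List[Atom] = []
    for start_idx, n_tokens, seen_tokens in seqs:
        cur, global_idx, remaining = start_idx, seen_tokens, n_tokens
        while remaining > 0:
            q_len = min(remaining, q_block_size)
            end_toks = global_idx + q_len
            atoms.append(Atom(cur, q_len, global_idx,
                              (end_toks + kv_block_size - 1) // kv_block_size, end_toks))
            cur, global_idx, remaining = cur + q_len, global_idx + q_len, remaining - q_len
    return atoms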
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
/*
|
||||
Construct the attention atoms given the ragged metadata for the current batch.
|
||||
This could largely be done at the Python level, but since we pack the KV ptr
|
||||
alongside the int32_t metadata, it gets very ugly to handle the mixed-width
|
||||
data structures (since we're packing them in a single tensor).
|
||||
*/
|
||||
int32_t build_atoms(torch::Tensor& atoms_ten,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& kv_ptrs,
|
||||
const int32_t q_block_size,
|
||||
const int32_t kv_block_size);
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
from ....ragged import RaggedBatchWrapper
|
||||
|
||||
|
||||
class AtomBuilder(DSKernelBase):
|
||||
"""
|
||||
C++ implementation to populate the attention atoms for the blocked attention
|
||||
kernel.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Triggers compilation of the C++ implementation.
|
||||
"""
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.build_atoms
|
||||
|
||||
def __call__(self, atoms: torch.Tensor, ragged_batch: RaggedBatchWrapper, q_block_size: int,
|
||||
kv_block_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""
|
||||
Populates the attention atoms for the blocked attention kernel.
|
||||
|
||||
Args:
|
||||
atoms (torch.Tensor): Pre-allocated int32 tensor of shape [max_atoms, 8]
|
||||
ragged_batch (RaggedBatchWrapper): Wrapper for the ragged batch.
|
||||
q_block_size (int): The block size for the queries (as determined by the
|
||||
attention implementation)
|
||||
kv_block_size (int): The block size for the keys/values (as determined by the
|
||||
attention implementation)
|
||||
|
||||
Returns:
Tuple[torch.Tensor, int]: The populated atom tensor and the number of valid atoms in it.
"""
|
||||
if atoms.device != torch.device("cpu"):
|
||||
raise RuntimeError("AtomBuilder must be called on tensors")
|
||||
|
||||
n_atoms = self.kernel(atoms, ragged_batch.batch_metadata_buffer(on_device=False),
|
||||
ragged_batch.inflight_seq_descriptors(on_device=False),
|
||||
ragged_batch.kv_ptrs(on_device=False), q_block_size, kv_block_size)
|
||||
return atoms, n_atoms
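
# Minimal usage sketch (illustrative only): `max_atoms` and the block sizes below are
# assumptions chosen by the caller, not values defined in this module.
def _example_build_atoms(wrapper: RaggedBatchWrapper, max_atoms: int = 1024) -> Tuple[torch.Tensor, int]:
    """Hypothetical helper showing how AtomBuilder is typically invoked."""
    atom_builder = AtomBuilder()
    # Each row is one attention atom: eight int32 slots (see attention_atom.h).
    atoms = torch.zeros((max_atoms, 8), dtype=torch.int32, device="cpu")
    atoms, n_atoms = atom_builder(atoms, wrapper, q_block_size=128, kv_block_size=64)
    return atoms[:n_atoms], n_atoms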
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .blocked_flash import *
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include "cuda.h"
|
||||
|
||||
struct AttentionAtom {
|
||||
/*
|
||||
The attention atom describes the workload of a particular query. The attention
|
||||
kernel will execute each ``AttentionAtom`` for each head of the model.
|
||||
*/
|
||||
|
||||
// Pointer to a list of KV block indices.
|
||||
int32_t* block_idx_list;
|
||||
|
||||
// Index of first token in the ragged batch associated with this atom.
|
||||
int32_t q_start_idx;
|
||||
|
||||
// Number of tokens in the ragged batch associated with this atom.
|
||||
int32_t q_len;
|
||||
|
||||
// Number of key/value blocks associated with this atom. All but the last are
|
||||
// assumed to be fully dense.
|
||||
int32_t kv_blocks;
|
||||
|
||||
// Number of tokens in the last key/value block.
|
||||
int32_t total_extent;
|
||||
|
||||
// Global index of the first token in the atom. For example, in a prompt continuation
|
||||
// in which we have already processed 768 tokens, this would be 768.
|
||||
int32_t global_q_idx;
|
||||
|
||||
// Unused
|
||||
int32_t unused;
|
||||
};
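// On a typical 64-bit build this struct is 32 bytes: an 8-byte block_idx_list pointer followed
// by six int32 fields. That is why the host-side atom tensor is allocated as [max_atoms, 8] int32.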
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "blocked_flash.h"
|
||||
#include "flash.h"
|
||||
|
||||
#define CHECK_SHAPE(x, ...) \
|
||||
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
|
||||
#x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
void flash_attn_by_atoms(at::Tensor& out,
|
||||
at::Tensor& q,
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& attention_atoms,
|
||||
const float softmax_scale,
|
||||
const bool is_causal)
|
||||
{
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
const int total_q = q.size(0);
|
||||
const int head_size = k.size(-1);
|
||||
const int num_heads_kv = k.size(-2);
|
||||
const int num_heads_q = q.size(-1) / head_size;
|
||||
|
||||
TORCH_CHECK(head_size <= 256, "head_size must be <= 256");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size must be divisible by 8");
|
||||
TORCH_CHECK(num_heads_q % num_heads_kv == 0, "num_heads_q must be divisible by num_heads_kv");
|
||||
|
||||
Flash_fwd_params params;
|
||||
|
||||
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.atoms = reinterpret_cast<AttentionAtom*>(attention_atoms.data_ptr());
|
||||
|
||||
// All strides are in elements, not bytes.
|
||||
params.q_row_stride = q.stride(0);
|
||||
params.k_row_stride = k.stride(1);
|
||||
params.v_row_stride = v.stride(1);
|
||||
params.o_row_stride = out.stride(0);
|
||||
|
||||
// Assume heads are contiguous.
|
||||
params.q_head_stride = head_size;
|
||||
params.k_head_stride = head_size;
|
||||
params.v_head_stride = head_size;
|
||||
params.o_head_stride = head_size;
|
||||
|
||||
// Head params
|
||||
params.h = num_heads_q;
|
||||
params.h_k = num_heads_kv;
|
||||
params.h_h_k_ratio = num_heads_q / num_heads_kv;
|
||||
params.d = head_size;
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
params.d_rounded = round_multiple(head_size, 32);
|
||||
params.num_atoms = attention_atoms.size(0);
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
|
||||
params.is_causal = is_causal;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void flash_attn_by_atoms(at::Tensor& out,
|
||||
at::Tensor& q,
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& attention_atoms,
|
||||
const float softmax_scale,
|
||||
const bool is_causal);
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from ....inference_utils import DtypeEnum
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
def get_q_block_size(head_size: int) -> int:
|
||||
"""
|
||||
Returns the query block size required by the kernel given a head size.
|
||||
"""
|
||||
cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda
|
||||
|
||||
if cc_major < 8:
|
||||
raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0")
|
||||
|
||||
if head_size <= 64:
|
||||
return 128
|
||||
elif head_size <= 160:
|
||||
if cc_minor != 0:
|
||||
return 64
|
||||
else:
|
||||
return 128
|
||||
elif head_size == 192:
|
||||
return 128
|
||||
elif head_size == 224:
|
||||
if cc_minor != 0:
|
||||
return 64
|
||||
else:
|
||||
return 128
|
||||
else:
|
||||
if cc_major == 8 and cc_minor == 0:
|
||||
return 128
|
||||
else:
|
||||
return 64
|
||||
|
||||
|
||||
def get_kv_block_size(head_size: int) -> int:
|
||||
"""
|
||||
Return the preferred granularity for the blocked KV-cache implementation.
|
||||
"""
|
||||
cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda
|
||||
|
||||
if cc_major < 8:
|
||||
raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0")
|
||||
|
||||
if head_size <= 64:
|
||||
return 128
|
||||
elif head_size != 160 or cc_minor != 0:
|
||||
return 64
|
||||
else:
|
||||
return 32
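

# Worked example of the selection logic above (the GPUs named are only for illustration):
#   head_size=128 on compute capability 8.0 (e.g. A100):  q_block=128, kv_block=64
#   head_size=128 on compute capability 8.6 (e.g. A6000): q_block=64,  kv_block=64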
|
||||
|
||||
|
||||
class BlockedFlashAttn(DSKernelBase):
|
||||
"""
|
||||
Modified implementation of flash-attn-2 tuned for inference on blocked KV-cache and wider
|
||||
range of input sequence lengths.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
|
||||
def __init__(self, head_size: int, dtype: DtypeEnum) -> None:
|
||||
"""
|
||||
Triggers any compilation of the kernels.
|
||||
"""
|
||||
if not isinstance(dtype, DtypeEnum):
|
||||
dtype = DtypeEnum(dtype)
|
||||
|
||||
if dtype not in BlockedFlashAttn.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported data types are {}".format(
|
||||
dtype, BlockedFlashAttn.supported_dtypes))
|
||||
|
||||
# NOTE: the alignment requirement is relaxed to 16 for testing; revert to 32 later.
if head_size % 16 != 0:
raise ValueError("Head size must be divisible by 16 (configured with {})".format(head_size))
|
||||
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.flash_attn_by_atoms
|
||||
|
||||
def __call__(self, out: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, atoms: torch.Tensor,
|
||||
softmax_scale: float) -> torch.Tensor:
|
||||
"""
|
||||
Flash attention implementation atop a blocked KV-cache. Atoms should be pre-populated.
|
||||
See attention_atom.h for further details on the structure of the information.
|
||||
|
||||
Arguments:
|
||||
out (torch.Tensor): Output tensor of shape [tokens, hidden_size]
|
||||
q (torch.Tensor): Query tensor of shape [tokens, hidden_size]
|
||||
k (torch.Tensor): Key cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension.
|
||||
v (torch.Tensor): Value cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension.
|
||||
atoms (torch.Tensor): Atom information tensor of shape [num_atoms, 8] and type int32.
|
||||
Not all fields are directly readable as int32 in this layout. See attention_atom.h for further details.
|
||||
softmax_scale (float): Softmax scale factor.
|
||||
|
||||
Returns:
|
||||
out (torch.Tensor): Output tensor of shape [tokens, hidden_size]
|
||||
"""
|
||||
self.kernel(out, q, k, v, atoms, softmax_scale, True)
|
||||
return out
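
# Minimal end-to-end sketch (illustrative only): the import path for AtomBuilder, the shapes,
# max_atoms, and the use of fp16 are assumptions rather than requirements of this module.
def _example_blocked_attention(out: torch.Tensor, q: torch.Tensor, k_cache: torch.Tensor,
                               v_cache: torch.Tensor, wrapper, head_size: int = 128) -> torch.Tensor:
    from ..atom_builder import AtomBuilder  # hypothetical relative import

    q_block = get_q_block_size(head_size)
    kv_block = get_kv_block_size(head_size)

    # Atoms are built on the host, then moved to the device for the kernel launch.
    atoms = torch.zeros((1024, 8), dtype=torch.int32, device="cpu")
    atoms, n_atoms = AtomBuilder()(atoms, wrapper, q_block, kv_block)

    attn = BlockedFlashAttn(head_size, DtypeEnum.fp16)
    return attn(out, q, k_cache, v_cache, atoms[:n_atoms].to(q.device), head_size**-0.5)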
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
/******************************************************************************
|
||||
Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#include "attention_atom.h"
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
using index_t = uint32_t;
|
||||
// The QKV matrices.
|
||||
void* __restrict__ q_ptr;
|
||||
void* __restrict__ k_ptr;
|
||||
void* __restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
// The O matrix (output).
|
||||
void* __restrict__ o_ptr;
|
||||
|
||||
// The attention metadata
|
||||
AttentionAtom* __restrict__ atoms;
|
||||
|
||||
// Total attention atoms
|
||||
int num_atoms;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
// The dimensions
|
||||
int d, d_rounded;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream);
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .embed import RaggedEmbeddingKernel
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "embed.h"
|
||||
#include "ragged_kernel_helpers.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using float_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using float_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kBFloat16) { \
|
||||
using float_t = __nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dispatch type"); \
|
||||
} \
|
||||
}()
|
||||
#else
|
||||
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kFloat32) { \
|
||||
using float_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kFloat16) { \
|
||||
using float_t = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dispatch type"); \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define DISPATCH_FOR_INT(DTYPE, ...) \
|
||||
[&] { \
|
||||
if (DTYPE == torch::kInt32) { \
|
||||
using int_t = int32_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (DTYPE == torch::kInt64) { \
|
||||
using int_t = int64_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported dispatch type"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
/*
|
||||
Embeddings kernel aware of ragged batch structure.
|
||||
*/
|
||||
void ragged_embed(torch::Tensor& embedded_tokens,
|
||||
torch::Tensor& input_ids,
|
||||
torch::Tensor& embedding_weight,
|
||||
c10::optional<torch::Tensor>& position_embedding_weight,
|
||||
int32_t pos_embed_offset,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs)
|
||||
{
|
||||
// We don't care about KV cache here, so just hardcoding 0s for block_size/num_blocks
|
||||
BatchWrapperCPP batch_wrapper =
|
||||
make_cpp_batch_wrapper(batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, 0, 0);
|
||||
|
||||
const int32_t n_tokens = input_ids.numel();
|
||||
const int32_t embed_dim = embedding_weight.size(1);
|
||||
const int32_t vocab_size = embedding_weight.size(0);
|
||||
|
||||
DISPATCH_FOR_INT(input_ids.scalar_type(), [&] {
|
||||
DISPATCH_FOR_FLOAT(embedding_weight.scalar_type(), [&] {
|
||||
float_t* pos_embed_ptr = nullptr;
|
||||
int32_t max_position_embed_idx = 0;
|
||||
if (position_embedding_weight.has_value()) {
|
||||
TORCH_CHECK(
|
||||
position_embedding_weight.value().options().dtype() ==
|
||||
embedding_weight.options().dtype(),
|
||||
"position_embedding_weight and embedding_weight must have the same dtype");
|
||||
pos_embed_ptr =
|
||||
reinterpret_cast<float_t*>(position_embedding_weight.value().data_ptr());
|
||||
max_position_embed_idx = position_embedding_weight.value().size(0) - 1;
|
||||
}
|
||||
|
||||
launch_ragged_embed_kernel((float_t*)embedded_tokens.data_ptr(),
|
||||
(const int_t*)input_ids.data_ptr(),
|
||||
(const float_t*)embedding_weight.data_ptr(),
|
||||
pos_embed_ptr,
|
||||
batch_wrapper,
|
||||
n_tokens,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embed_idx,
|
||||
pos_embed_offset,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
});
|
||||
});
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "embed.cuh"
|
||||
#include "memory_access_utils.h"
|
||||
#include "ragged_dtypes.h"
|
||||
|
||||
namespace embed {
|
||||
|
||||
constexpr int granularity = 16;
|
||||
constexpr int threads = 512;
|
||||
|
||||
} // namespace embed
|
||||
|
||||
template <typename TokenType, typename EmbedType>
|
||||
__global__ void ragged_embed_kernel(EmbedType* embedded_tokens,
|
||||
const TokenType* input_ids,
|
||||
const EmbedType* embedding_weight,
|
||||
const EmbedType* position_weight,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int32_t embed_dim,
|
||||
const int32_t vocab_size,
|
||||
const int32_t max_position_embed_idx,
|
||||
const int32_t position_embed_offset)
|
||||
{
|
||||
constexpr int T_vector = embed::granularity / sizeof(EmbedType);
|
||||
|
||||
const int32_t token_idx = blockIdx.y;
|
||||
|
||||
// It's possible our batch is padded (under CG conditions typically)
|
||||
if (token_idx >= batch_desc.batch_metadata->n_tokens) return;
|
||||
|
||||
TokenType token_value = input_ids[token_idx];
|
||||
|
||||
if (token_value >= vocab_size || token_value < 0) {
|
||||
// TODO(cmikeh2): The token id is out of range; how to surface this error is still undecided,
// so for now the token is simply skipped.
|
||||
return;
|
||||
}
|
||||
|
||||
const EmbedType* embedding_row = embedding_weight + token_value * embed_dim;
|
||||
EmbedType* dest_row = embedded_tokens + token_idx * embed_dim;
|
||||
|
||||
const int channel_offset = (threadIdx.x + embed::threads * blockIdx.x) * T_vector;
|
||||
|
||||
if (channel_offset < embed_dim) {
|
||||
EmbedType reg_buf[T_vector];
|
||||
|
||||
mem_access::load_global<embed::granularity>(reg_buf, embedding_row + channel_offset);
|
||||
|
||||
if (position_weight != nullptr) {
|
||||
// Map the token to its global idx (the indirect memory accesses aren't ideal, but acceptable here)
|
||||
const int32_t seq_idx = batch_desc.tokens_to_seq[token_idx];
|
||||
const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_idx];
|
||||
int32_t pos_emb_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx);
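// e.g. for a sequence that has already seen 768 tokens and whose first in-flight token sits at
// start_idx 16, the token at batch index 20 maps to position 768 + (20 - 16) = 772.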
|
||||
|
||||
// The position embedding offset is an OPT-specific feature.
|
||||
pos_emb_idx = pos_emb_idx + position_embed_offset;
|
||||
|
||||
// Clamp to the valid range of the position embedding table to avoid out-of-bounds reads.
|
||||
pos_emb_idx = (pos_emb_idx < 0) ? 0 : pos_emb_idx;
|
||||
pos_emb_idx = (pos_emb_idx >= max_position_embed_idx) ? max_position_embed_idx
|
||||
: pos_emb_idx;
|
||||
|
||||
const EmbedType* position_embedding_row = position_weight + pos_emb_idx * embed_dim;
|
||||
|
||||
EmbedType pos_buf[T_vector];
|
||||
mem_access::load_global<embed::granularity>(pos_buf,
|
||||
position_embedding_row + channel_offset);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < T_vector; i++) { reg_buf[i] += pos_buf[i]; }
|
||||
}
|
||||
|
||||
mem_access::store_global<embed::granularity>(dest_row + channel_offset, reg_buf);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TokenType, typename EmbedType>
|
||||
void launch_ragged_embed_kernel(EmbedType* embedded_tokens,
|
||||
const TokenType* input_ids,
|
||||
const EmbedType* embedding_weight,
|
||||
const EmbedType* position_weight,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int32_t n_tokens,
|
||||
const int32_t embed_dim,
|
||||
const int32_t vocab_size,
|
||||
const int32_t max_position_embed_idx,
|
||||
const int32_t position_embed_offset,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
constexpr int T_vector = embed::granularity / sizeof(EmbedType);
|
||||
constexpr int elems_per_block = embed::threads * T_vector;
|
||||
const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block;
|
||||
|
||||
const dim3 grid_dim(parallel_blocks, n_tokens, 1);
|
||||
const dim3 block_dim(embed::threads, 1, 1);
|
||||
|
||||
ragged_embed_kernel<TokenType, EmbedType>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(embedded_tokens,
|
||||
input_ids,
|
||||
embedding_weight,
|
||||
position_weight,
|
||||
batch_desc,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embed_idx,
|
||||
position_embed_offset);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_EMBED_FOR_TYPES(TOKEN_TYPE, EMBED_TYPE) \
|
||||
template void launch_ragged_embed_kernel<TOKEN_TYPE, EMBED_TYPE>( \
|
||||
EMBED_TYPE * embedded_tokens, \
|
||||
const TOKEN_TYPE* input_ids, \
|
||||
const EMBED_TYPE* embedding_weight, \
|
||||
const EMBED_TYPE* position_weight, \
|
||||
const BatchWrapperCPP batch_descriptor, \
|
||||
const int32_t n_tokens, \
|
||||
const int32_t embed_dim, \
|
||||
const int32_t vocab_size, \
|
||||
const int32_t max_position_embed_idx, \
|
||||
const int32_t position_embed_offset, \
|
||||
cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int32_t, float)
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int64_t, float)
|
||||
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int32_t, __half)
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int64_t, __half)
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int32_t, __nv_bfloat16)
|
||||
INSTANTIATE_EMBED_FOR_TYPES(int64_t, __nv_bfloat16)
|
||||
#endif
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "ragged_dtypes.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
template <typename TokenType, typename EmbedType>
|
||||
void launch_ragged_embed_kernel(EmbedType* embedded_tokens,
|
||||
const TokenType* input_ids,
|
||||
const EmbedType* embedding_weight,
|
||||
const EmbedType* position_weight,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int32_t n_tokens,
|
||||
const int32_t embed_dim,
|
||||
const int32_t vocab_size,
|
||||
const int32_t max_position_embed_idx,
|
||||
const int32_t position_embed_offset,
|
||||
cudaStream_t stream);
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "embed.cuh"
|
||||
|
||||
/*
|
||||
Embeddings kernel aware of ragged batch structure.
|
||||
*/
|
||||
void ragged_embed(torch::Tensor& embedded_tokens,
|
||||
torch::Tensor& input_ids,
|
||||
torch::Tensor& embedding_weight,
|
||||
c10::optional<torch::Tensor>& position_weight,
|
||||
int32_t position_embed_offset,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs);
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ... import DSKernelBase
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
from ....inference_utils import elem_size
|
||||
from ....ragged import RaggedBatchWrapper
|
||||
|
||||
|
||||
class RaggedEmbeddingKernel(DSKernelBase):
|
||||
"""
|
||||
Ragged-aware CUDA kernel implementation for an embedding lookup. This will only look up
the necessary tokens for a padded batch (i.e. if we are running under CUDA graphs with a
slightly larger batch size than the actual number of tokens).
|
||||
"""
|
||||
|
||||
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
supported_token_dtypes = [torch.int32, torch.int64]
|
||||
|
||||
def __init__(self, embed_dtype: torch.dtype, token_dtype: torch.dtype, embed_dim: int) -> None:
|
||||
"""
|
||||
Args:
|
||||
embed_dtype (torch.dtype): Data type of the embedding table and output dtype.
|
||||
Supported values are torch.float16, torch.bfloat16, and torch.float32.
|
||||
token_dtype (torch.dtype): Data type of the token ids. Supported values are
|
||||
torch.int32 and torch.int64.
|
||||
embed_dim (int): Embedding dimension. The row size in bytes (embed_dim * dtype size) must be a multiple of 16.
|
||||
"""
|
||||
if embed_dtype not in RaggedEmbeddingKernel.supported_dtypes:
|
||||
raise ValueError("Unsupported embedding data type: {}, supported_dtypes are {}".format(
|
||||
embed_dtype, RaggedEmbeddingKernel.supported_dtypes))
|
||||
|
||||
if token_dtype not in RaggedEmbeddingKernel.supported_token_dtypes:
|
||||
raise ValueError("Unsupported token data type: {}, supported_dtypes are {}".format(
|
||||
token_dtype, RaggedEmbeddingKernel.supported_token_dtypes))
|
||||
|
||||
if elem_size(embed_dtype) * embed_dim % 16 != 0:
|
||||
raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(embed_dim))
|
||||
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.ragged_embed
|
||||
|
||||
def __call__(self,
|
||||
embedded_tokens: torch.Tensor,
|
||||
ragged_wrapper: RaggedBatchWrapper,
|
||||
embedding_weight: torch.Tensor,
|
||||
position_embed_weight: Optional[torch.Tensor] = None,
|
||||
position_embed_offset: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Ragged-aware embedding lookup.
|
||||
|
||||
Args:
|
||||
embedded_tokens (torch.Tensor): Output tensor of shape [num_tokens, embed_dim]
|
||||
ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch.
|
||||
embedding_weight (torch.Tensor): Embedding table of shape [vocab_size, embed_dim]
position_embed_weight (Optional[torch.Tensor]): Optional position embedding table of shape [max_positions, embed_dim]
position_embed_offset (int): Offset added to each position index before the position lookup.
"""
|
||||
self.kernel(embedded_tokens, ragged_wrapper.input_ids(),
|
||||
embedding_weight, position_embed_weight, position_embed_offset,
|
||||
ragged_wrapper.batch_metadata_buffer(), ragged_wrapper.inflight_seq_descriptors(),
|
||||
ragged_wrapper.tokens_to_seq(), ragged_wrapper.kv_ptrs())
|
||||
return embedded_tokens
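
# Minimal usage sketch (illustrative only): `n_tokens` is assumed to be supplied by the caller,
# and int32 token ids are just one of the supported options.
def _example_ragged_embed(wrapper: RaggedBatchWrapper, embedding_weight: torch.Tensor,
                          n_tokens: int) -> torch.Tensor:
    embed_dim = embedding_weight.size(1)
    kernel = RaggedEmbeddingKernel(embedding_weight.dtype, torch.int32, embed_dim)
    out = torch.empty((n_tokens, embed_dim),
                      dtype=embedding_weight.dtype,
                      device=embedding_weight.device)
    return kernel(out, wrapper, embedding_weight)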
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .blocked_kv_rotary import *
|
||||
from .blocked_trained_kv_rotary import *
|
||||
from .linear_blocked_kv_copy import *
|
|
@ -0,0 +1,188 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "blocked_kv_rotary.h"
|
||||
#include "ragged_kernel_helpers.h"
|
||||
|
||||
#define DISPATCH_KV_ROTARY(T_TYPE, C_TYPE) \
|
||||
if (q.options().dtype() == torch::T_TYPE) { \
|
||||
launch_kv_rotary_kernel<C_TYPE>((C_TYPE*)kv_cache.data_ptr(), \
|
||||
(C_TYPE*)q.data_ptr(), \
|
||||
(C_TYPE*)k.data_ptr(), \
|
||||
(C_TYPE*)v.data_ptr(), \
|
||||
(C_TYPE*)inv_freq_ptr, \
|
||||
batch_wrapper, \
|
||||
qkv_stride, \
|
||||
kv_cache_stride, \
|
||||
v_offset, \
|
||||
inv_freq_stride, \
|
||||
q_ratio, \
|
||||
head_size, \
|
||||
n_tokens, \
|
||||
n_q_heads, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
}
|
||||
|
||||
/*
|
||||
Rotary position embeddings + copy into KV cache. This implementation assumes
|
||||
that the inverse frequencies are read from global memory rather than
|
||||
synthesized in the kernel.
|
||||
|
||||
Arguments:
|
||||
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
q: [n_tokens, n_q_heads * head_size]
|
||||
k: [n_tokens, n_kv_heads * head_size]
|
||||
v: [n_tokens, n_kv_heads * head_size]
|
||||
inv_freq: [max_seq_len, head_size // 2]
|
||||
*/
|
||||
void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& inv_freq,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs)
|
||||
{
|
||||
const int32_t n_tokens = q.size(0);
|
||||
TORCH_CHECK(n_tokens == k.size(0));
|
||||
TORCH_CHECK(n_tokens == v.size(0));
|
||||
|
||||
// Dimensions
|
||||
const int32_t block_size = kv_cache.size(1);
|
||||
const int32_t n_kv_heads = kv_cache.size(3);
|
||||
const int32_t head_size = kv_cache.size(4);
|
||||
|
||||
// Strides
|
||||
const int32_t qkv_stride = q.stride(0); // Per token
|
||||
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
|
||||
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
|
||||
const int32_t inv_freq_stride = inv_freq.stride(0); // Per token idx
|
||||
|
||||
const int n_q_heads = q.size(1) / head_size;
|
||||
const int q_ratio = n_q_heads / n_kv_heads;
|
||||
|
||||
void* inv_freq_ptr = (void*)inv_freq.data_ptr();
|
||||
|
||||
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
|
||||
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
|
||||
|
||||
DISPATCH_KV_ROTARY(kHalf, __half);
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
Rotary position embeddings + copy into KV cache. This implementation assumes
|
||||
that the inverse frequencies should be synthesized in the kernel.
|
||||
|
||||
Arguments:
|
||||
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
q: [n_tokens, n_q_heads * head_size]
|
||||
k: [n_tokens, n_kv_heads * head_size]
|
||||
v: [n_tokens, n_kv_heads * head_size]
|
||||
*/
|
||||
void kv_rotary_embeddings(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs)
|
||||
{
|
||||
const int32_t n_tokens = q.size(0);
|
||||
TORCH_CHECK(n_tokens == k.size(0));
|
||||
TORCH_CHECK(n_tokens == v.size(0));
|
||||
|
||||
// Dimensions
|
||||
const int32_t block_size = kv_cache.size(1);
|
||||
const int32_t n_kv_heads = kv_cache.size(3);
|
||||
const int32_t head_size = kv_cache.size(4);
|
||||
|
||||
// Strides
|
||||
const int32_t qkv_stride = q.stride(0); // Per token
|
||||
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
|
||||
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
|
||||
const int32_t inv_freq_stride = 0;  // Unused; frequencies are synthesized in the kernel
|
||||
|
||||
const int n_q_heads = q.size(1) / head_size;
|
||||
const int q_ratio = n_q_heads / n_kv_heads;
|
||||
|
||||
void* inv_freq_ptr = nullptr;
|
||||
|
||||
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
|
||||
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
|
||||
|
||||
DISPATCH_KV_ROTARY(kHalf, __half);
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_KV_COPY(T_TYPE, C_TYPE) \
|
||||
if (q.options().dtype() == torch::T_TYPE) { \
|
||||
launch_kv_copy_kernel<C_TYPE>((C_TYPE*)kv_cache.data_ptr(), \
|
||||
(C_TYPE*)q.data_ptr(), \
|
||||
(C_TYPE*)k.data_ptr(), \
|
||||
(C_TYPE*)v.data_ptr(), \
|
||||
batch_wrapper, \
|
||||
qkv_stride, \
|
||||
kv_cache_stride, \
|
||||
v_offset, \
|
||||
q_ratio, \
|
||||
head_size, \
|
||||
n_tokens, \
|
||||
n_q_heads, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
}
|
||||
|
||||
/*
|
||||
Copy into linear KV cache.
|
||||
*/
|
||||
void linear_kv_copy(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs)
|
||||
{
|
||||
const int32_t n_tokens = q.size(0);
|
||||
TORCH_CHECK(n_tokens == k.size(0));
|
||||
TORCH_CHECK(n_tokens == v.size(0));
|
||||
|
||||
// Dimensions
|
||||
const int32_t block_size = kv_cache.size(1);
|
||||
const int32_t n_kv_heads = kv_cache.size(3);
|
||||
const int32_t head_size = kv_cache.size(4);
|
||||
|
||||
// Strides
|
||||
const int32_t qkv_stride = q.stride(0); // Per token
|
||||
TORCH_CHECK(qkv_stride == k.stride(0));
|
||||
TORCH_CHECK(qkv_stride == v.stride(0));
|
||||
|
||||
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
|
||||
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
|
||||
|
||||
const int n_q_heads = q.size(1) / head_size;
|
||||
|
||||
TORCH_CHECK(n_q_heads % n_kv_heads == 0);
|
||||
const int q_ratio = n_q_heads / n_kv_heads;
|
||||
|
||||
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
|
||||
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
|
||||
|
||||
DISPATCH_KV_COPY(kHalf, __half);
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_KV_COPY(kBFloat16, __nv_bfloat16);
|
||||
#endif
|
||||
}
|
|
@ -0,0 +1,314 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "blocked_kv_rotary.cuh"
|
||||
#include "conversion_utils.h"
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "memory_access_utils.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace kv_rot {
|
||||
|
||||
constexpr int granularity = 16;
|
||||
constexpr int threads = 256;
|
||||
|
||||
} // namespace kv_rot
|
||||
|
||||
/*
|
||||
Supports head size 32, 64, 128, 256
|
||||
*/
|
||||
|
||||
template <typename T, int qRatio, int headSize, bool doRotary>
|
||||
__global__ void kv_rotary_pos_kernel(T* kv_cache,
|
||||
T* q,
|
||||
T* k,
|
||||
T* v,
|
||||
const T* inv_freq,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int qkv_stride,
|
||||
const int kv_cache_stride,
|
||||
const int v_offset,
|
||||
const int inv_freq_stride)
|
||||
{
|
||||
// Derived constexpr
|
||||
constexpr int vector_T = kv_rot::granularity / sizeof(T);
|
||||
constexpr int threads_per_head = headSize / vector_T;
|
||||
constexpr int half_head_size = headSize >> 1;
|
||||
constexpr int tokens_per_block = kv_rot::threads / threads_per_head;
|
||||
|
||||
// CG helpers
|
||||
cg::thread_block tb = cg::this_thread_block();
|
||||
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
||||
cg::thread_block_tile<threads_per_head> head_group =
|
||||
cg::tiled_partition<threads_per_head>(warp);
|
||||
|
||||
// Parallelize on the head dimension for X blocks
|
||||
const int head_idx = blockIdx.x;
|
||||
|
||||
const int block_seq_idx = threadIdx.x / threads_per_head;
|
||||
const int base_neuron_idx = (threadIdx.x * vector_T) % headSize;
|
||||
const int half_idx = base_neuron_idx % half_head_size;
|
||||
const int half_head_lanes = threads_per_head / 2;
|
||||
|
||||
// Multiple tokens processed by the same threadblock
|
||||
const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx;
|
||||
const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens;
|
||||
const bool load_inv_freq = (inv_freq != nullptr) && valid_token;
|
||||
|
||||
// If we have GQA, then only one of the Q heads needs to do rotary + copy
|
||||
// for each of the heads in the group.
|
||||
bool need_kv = head_idx % qRatio == 0;
|
||||
// Make sure the following code is warp uniform
|
||||
need_kv = warp.shfl(need_kv, 0);
|
||||
|
||||
const int kv_head_idx = head_idx / qRatio;
|
||||
|
||||
// Ensure we don't access invalid portions of the seq_metadata
|
||||
const int32_t seq_id = (valid_token) ? batch_desc.tokens_to_seq[token_idx] : 0;
|
||||
const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id];
|
||||
// This will give an invalid index if valid_token is false, but should never affect memory.
|
||||
const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx);
|
||||
|
||||
T* q_row = q + token_idx * qkv_stride + head_idx * headSize;
|
||||
T q_reg[vector_T];
|
||||
|
||||
if (need_kv) {
|
||||
// The following logic assumes a linearly blocked KV cache. This means that no sparsity has
|
||||
// been introduced into cache history.
|
||||
const KVCacheDescriptor kv_desc = batch_desc.kv_desc;
|
||||
const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size;
|
||||
const int32_t mapped_kv_block_idx =
|
||||
(valid_token) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0;
|
||||
|
||||
const int32_t kv_block_offset = global_token_idx % kv_desc.block_size;
|
||||
const int32_t kv_offset =
|
||||
(mapped_kv_block_idx * kv_desc.block_size + kv_block_offset) * kv_cache_stride +
|
||||
kv_head_idx * headSize;
|
||||
|
||||
// Load indices from QKV output
|
||||
T* k_row = k + token_idx * qkv_stride + kv_head_idx * headSize;
|
||||
T* v_row = v + token_idx * qkv_stride + kv_head_idx * headSize;
|
||||
|
||||
T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T];
|
||||
|
||||
mem_access::load_global<kv_rot::granularity>(q_reg, q_row + base_neuron_idx, valid_token);
|
||||
mem_access::load_global<kv_rot::granularity>(k_reg, k_row + base_neuron_idx, valid_token);
|
||||
mem_access::load_global<kv_rot::granularity>(v_reg, v_row + base_neuron_idx, valid_token);
|
||||
mem_access::load_global<kv_rot::granularity>(
|
||||
inv_freq_reg, inv_freq + half_idx, load_inv_freq);
|
||||
|
||||
if constexpr (doRotary) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vector_T; i++) {
|
||||
const int head_neuron_idx = base_neuron_idx + i;
|
||||
|
||||
float inv_freq_flt;
|
||||
if (inv_freq != nullptr) {
|
||||
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
|
||||
} else {
|
||||
inv_freq_flt =
|
||||
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
|
||||
// Conversion to T and back means that both branches of this if statement
|
||||
// will produce the same results if using the same algo for producing the
|
||||
// freqs.
|
||||
T trunc_freq = conversion::to<T>(1.0 / powf(10000.0, inv_freq_flt));
|
||||
inv_freq_flt = conversion::to<float>(trunc_freq) * (float)global_token_idx;
|
||||
}
|
||||
|
||||
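// Rotary pairing: combined with the sign flip below, the shfl_xor across half_head_lanes pairs
// element i with element i + headSize/2, yielding x_i*cos - x_{i+d/2}*sin for the first half of
// the head and x_{i+d/2}*cos + x_i*sin for the second half.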
float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
|
||||
float q_f = conversion::to<float>(q_reg[i]);
|
||||
float k_f = conversion::to<float>(k_reg[i]);
|
||||
float q_rot = q_f * rotary_sign;
|
||||
float k_rot = k_f * rotary_sign;
|
||||
|
||||
const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes);
|
||||
const float k_rot_temp = head_group.shfl_xor(k_rot, half_head_lanes);
|
||||
|
||||
q_reg[i] =
|
||||
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
|
||||
k_reg[i] =
|
||||
conversion::to<T>(k_f * cosf(inv_freq_flt) + k_rot_temp * sinf(inv_freq_flt));
|
||||
}
|
||||
}
|
||||
|
||||
if (valid_token) {
|
||||
mem_access::store_global<kv_rot::granularity>(kv_cache + kv_offset + base_neuron_idx,
|
||||
k_reg);
|
||||
mem_access::store_global<kv_rot::granularity>(
|
||||
kv_cache + kv_offset + base_neuron_idx + v_offset, v_reg);
|
||||
}
|
||||
} else {
|
||||
T inv_freq_reg[vector_T];
|
||||
|
||||
mem_access::load_global<kv_rot::granularity>(q_reg, q_row + base_neuron_idx, valid_token);
|
||||
mem_access::load_global<kv_rot::granularity>(
|
||||
inv_freq_reg, inv_freq + half_idx, load_inv_freq);
|
||||
|
||||
if constexpr (doRotary) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vector_T; i++) {
|
||||
const int head_neuron_idx = base_neuron_idx + i;
|
||||
|
||||
float inv_freq_flt;
|
||||
if (inv_freq != nullptr) {
|
||||
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
|
||||
} else {
|
||||
inv_freq_flt =
|
||||
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
|
||||
inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx;
|
||||
}
|
||||
|
||||
float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
|
||||
float q_f = conversion::to<float>(q_reg[i]);
|
||||
float q_rot = q_f * rotary_sign;
|
||||
|
||||
const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes);
|
||||
|
||||
q_reg[i] =
|
||||
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (valid_token && doRotary) {
|
||||
mem_access::store_global<kv_rot::granularity>(q_row + base_neuron_idx, q_reg);
|
||||
}
|
||||
}
|
||||
|
||||
#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE) \
|
||||
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
|
||||
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, true> \
|
||||
<<<grid, block, 0, stream>>>(kv_cache, \
|
||||
q, \
|
||||
k, \
|
||||
v, \
|
||||
inv_freq, \
|
||||
batch_desc, \
|
||||
qkv_stride, \
|
||||
kv_cache_stride, \
|
||||
v_offset, \
|
||||
inv_freq_stride);
|
||||
|
||||
template <typename T>
|
||||
void launch_kv_rotary_kernel(T* kv_cache,
|
||||
T* q,
|
||||
T* k,
|
||||
T* v,
|
||||
T* inv_freq,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int qkv_stride,
|
||||
const int kv_cache_stride,
|
||||
const int v_offset,
|
||||
const int inv_freq_stride,
|
||||
const int q_ratio,
|
||||
const int head_size,
|
||||
const int n_tokens,
|
||||
const int n_q_heads,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
constexpr int vector_T = kv_rot::granularity / sizeof(T);
|
||||
const int threads_per_head = head_size / vector_T;
|
||||
const int tokens_per_block = kv_rot::threads / threads_per_head;
|
||||
|
||||
const dim3 block(kv_rot::threads);
|
||||
const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block;
|
||||
const dim3 grid(n_q_heads, token_blocks);
|
||||
|
||||
DISPATCH_KV_ROTARY_IMPL(1, 64)
|
||||
DISPATCH_KV_ROTARY_IMPL(1, 128)
|
||||
DISPATCH_KV_ROTARY_IMPL(2, 64)
|
||||
DISPATCH_KV_ROTARY_IMPL(2, 128)
|
||||
DISPATCH_KV_ROTARY_IMPL(4, 64)
|
||||
DISPATCH_KV_ROTARY_IMPL(4, 128)
|
||||
DISPATCH_KV_ROTARY_IMPL(5, 64)
|
||||
DISPATCH_KV_ROTARY_IMPL(5, 128)
|
||||
DISPATCH_KV_ROTARY_IMPL(8, 64)
|
||||
DISPATCH_KV_ROTARY_IMPL(8, 128)
|
||||
}
|
||||
|
||||
#define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \
|
||||
template void launch_kv_rotary_kernel<TYPE>(TYPE * kv_cache, \
|
||||
TYPE * q, \
|
||||
TYPE * k, \
|
||||
TYPE * v, \
|
||||
TYPE * inv_freq, \
|
||||
const BatchWrapperCPP batch_desc, \
|
||||
const int qkv_stride, \
|
||||
const int kv_cache_stride, \
|
||||
const int v_offset, \
|
||||
const int inv_freq_stride, \
|
||||
const int q_ratio, \
|
||||
const int head_size, \
|
||||
const int n_tokens, \
|
||||
const int n_q_heads, \
|
||||
cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_KV_ROTARY_KERNEL(__half)
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16)
|
||||
#endif
|
||||
|
||||
#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \
|
||||
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
|
||||
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, false><<<grid, block, 0, stream>>>( \
|
||||
kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0);
|
||||
|
||||
template <typename T>
|
||||
void launch_kv_copy_kernel(T* kv_cache,
|
||||
T* q,
|
||||
T* k,
|
||||
T* v,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int qkv_stride,
|
||||
const int kv_cache_stride,
|
||||
const int v_offset,
|
||||
const int q_ratio,
|
||||
const int head_size,
|
||||
const int n_tokens,
|
||||
const int n_q_heads,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
constexpr int vector_T = kv_rot::granularity / sizeof(T);
|
||||
const int threads_per_head = head_size / vector_T;
|
||||
const int tokens_per_block = kv_rot::threads / threads_per_head;
|
||||
|
||||
const dim3 block(kv_rot::threads);
|
||||
const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block;
|
||||
const dim3 grid(n_q_heads, token_blocks);
|
||||
|
||||
DISPATCH_KV_COPY_IMPL(1, 64)
|
||||
DISPATCH_KV_COPY_IMPL(1, 128)
|
||||
DISPATCH_KV_COPY_IMPL(2, 64)
|
||||
DISPATCH_KV_COPY_IMPL(2, 128)
|
||||
DISPATCH_KV_COPY_IMPL(4, 64)
|
||||
DISPATCH_KV_COPY_IMPL(4, 128)
|
||||
DISPATCH_KV_COPY_IMPL(5, 64)
|
||||
DISPATCH_KV_COPY_IMPL(5, 128)
|
||||
DISPATCH_KV_COPY_IMPL(8, 64)
|
||||
DISPATCH_KV_COPY_IMPL(8, 128)
|
||||
}
|
||||
|
||||
#define INSTANTIATE_KV_COPY_KERNEL(TYPE) \
|
||||
template void launch_kv_copy_kernel<TYPE>(TYPE * kv_cache, \
|
||||
TYPE * q, \
|
||||
TYPE * k, \
|
||||
TYPE * v, \
|
||||
const BatchWrapperCPP batch_desc, \
|
||||
const int qkv_stride, \
|
||||
const int kv_cache_stride, \
|
||||
const int v_offset, \
|
||||
const int q_ratio, \
|
||||
const int head_size, \
|
||||
const int n_tokens, \
|
||||
const int n_q_heads, \
|
||||
cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_KV_COPY_KERNEL(__half)
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_KV_COPY_KERNEL(__nv_bfloat16)
|
||||
#endif
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ds_kernel_utils.h"
|
||||
#include "ragged_dtypes.h"
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void launch_kv_rotary_kernel(T* kv_cache,
|
||||
T* q,
|
||||
T* k,
|
||||
T* v,
|
||||
T* inv_freq,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int qkv_stride,
|
||||
const int kv_cache_stride,
|
||||
const int v_offset,
|
||||
const int inv_freq_stride,
|
||||
const int q_ratio,
|
||||
const int head_size,
|
||||
const int n_tokens,
|
||||
const int n_q_heads,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_kv_copy_kernel(T* kv_cache,
|
||||
T* q,
|
||||
T* k,
|
||||
T* v,
|
||||
const BatchWrapperCPP batch_desc,
|
||||
const int qkv_stride,
|
||||
const int kv_cache_stride,
|
||||
const int v_offset,
|
||||
const int q_ratio,
|
||||
const int head_size,
|
||||
const int n_tokens,
|
||||
const int n_q_heads,
|
||||
cudaStream_t stream);
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include "blocked_kv_rotary.cuh"
|
||||
|
||||
/*
|
||||
Rotary position embeddings + copy into KV cache. This implementation assumes
|
||||
that the inverse frequencies are read from global memory rather than
|
||||
synthesized in the kernel.
|
||||
|
||||
Arguments:
|
||||
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
q: [n_tokens, n_q_heads * head_size]
|
||||
k: [n_tokens, n_kv_heads * head_size]
|
||||
v: [n_tokens, n_kv_heads * head_size]
|
||||
inv_freq: [max_seq_len, head_size // 2]
|
||||
*/
|
||||
void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& inv_freq,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs);
|
||||
|
||||
/*
|
||||
Rotary position embeddings + copy into KV cache. This implementation assumes
|
||||
that the inverse frequencies should be synthesized in the kernel.
|
||||
|
||||
Arguments:
|
||||
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
q: [n_tokens, n_q_heads * head_size]
|
||||
k: [n_tokens, n_kv_heads * head_size]
|
||||
v: [n_tokens, n_kv_heads * head_size]
|
||||
*/
|
||||
void kv_rotary_embeddings(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs);
|
||||
|
||||
/*
|
||||
Copy into linear KV cache.
|
||||
*/
|
||||
void linear_kv_copy(torch::Tensor& kv_cache,
|
||||
torch::Tensor& q,
|
||||
torch::Tensor& k,
|
||||
torch::Tensor& v,
|
||||
torch::Tensor& batch_metadata,
|
||||
torch::Tensor& seq_metadata,
|
||||
torch::Tensor& tokens_to_seq,
|
||||
torch::Tensor& kv_ptrs);
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
from ....inference_utils import DtypeEnum
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
from ....ragged import RaggedBatchWrapper
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
class BlockedRotaryEmbeddings(DSKernelBase):
|
||||
"""
|
||||
CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys
|
||||
before copying into a blocked KV cache.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_head_sizes = [64, 128]
|
||||
supported_q_ratios = [1, 2, 4, 5, 8]
|
||||
|
||||
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
|
||||
"""
|
||||
Args:
|
||||
head_size: The size of the attention head.
|
||||
n_q_heads: Number of query heads.
n_kv_heads: Number of key/value heads. The q/kv head ratio used for GQA is derived from these.
|
||||
dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
|
||||
"""
|
||||
|
||||
q_ratio = n_q_heads // n_kv_heads
|
||||
|
||||
if head_size not in BlockedRotaryEmbeddings.supported_head_sizes:
|
||||
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
|
||||
head_size, BlockedRotaryEmbeddings.supported_head_sizes))
|
||||
|
||||
if q_ratio not in BlockedRotaryEmbeddings.supported_q_ratios:
|
||||
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
|
||||
q_ratio, BlockedRotaryEmbeddings.supported_q_ratios))
|
||||
|
||||
if not isinstance(dtype, DtypeEnum):
|
||||
dtype = DtypeEnum(dtype)
|
||||
|
||||
if dtype not in BlockedRotaryEmbeddings.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
dtype, BlockedRotaryEmbeddings.supported_dtypes))
|
||||
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.kv_rotary_embeddings
|
||||
self.head_size = head_size
|
||||
self.n_q_heads = n_q_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
|
||||
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
|
||||
"""
|
||||
Perform rotary embeddings on the queries and keys before copying into a blocked KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
|
||||
ragged_batch: Wrapper for the ragged batch.
|
||||
"""
|
||||
|
||||
q = qkv[:, :self.head_size * self.n_q_heads]
|
||||
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
|
||||
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
|
||||
|
||||
self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(),
|
||||
ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
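
# Minimal usage sketch (illustrative only): shapes follow the layouts documented above, and
# fp16 is just one of the supported dtypes.
def _example_rotary_kv_copy(kv_cache: torch.Tensor, qkv: torch.Tensor,
                            ragged_batch: RaggedBatchWrapper) -> None:
    # kv_cache: [num_blocks, block_size, 2, n_kv_heads, head_size]
    # qkv:      [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
    n_kv_heads, head_size = kv_cache.shape[3], kv_cache.shape[4]
    n_q_heads = qkv.shape[1] // head_size - 2 * n_kv_heads
    rotary = BlockedRotaryEmbeddings(head_size, n_q_heads, n_kv_heads, DtypeEnum.fp16)
    rotary(kv_cache, qkv, ragged_batch)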
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from ....inference_utils import DtypeEnum
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
from ....ragged import RaggedBatchWrapper
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
class BlockedTrainedRotaryEmbeddings(DSKernelBase):
|
||||
"""
|
||||
CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys
|
||||
before copying into a blocked KV cache.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_head_sizes = [64, 128]
|
||||
supported_q_ratios = [1, 2, 4, 5, 8]
|
||||
|
||||
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
|
||||
"""
|
||||
Args:
|
||||
head_size: The size of the attention head.
|
||||
dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
|
||||
"""
|
||||
|
||||
q_ratio = n_q_heads // n_kv_heads
|
||||
|
||||
if head_size not in BlockedTrainedRotaryEmbeddings.supported_head_sizes:
|
||||
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
|
||||
head_size, BlockedTrainedRotaryEmbeddings.supported_head_sizes))
|
||||
|
||||
if q_ratio not in BlockedTrainedRotaryEmbeddings.supported_q_ratios:
|
||||
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
|
||||
q_ratio, BlockedTrainedRotaryEmbeddings.supported_q_ratios))
|
||||
|
||||
if not isinstance(dtype, DtypeEnum):
|
||||
dtype = DtypeEnum(dtype)
|
||||
|
||||
if dtype not in BlockedTrainedRotaryEmbeddings.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
dtype, BlockedTrainedRotaryEmbeddings.supported_dtypes))
|
||||
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.kv_trained_rotary_embeddings
|
||||
self.head_size = head_size
|
||||
self.n_q_heads = n_q_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
|
||||
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper,
|
||||
inverse_freqs: torch.Tensor) -> None:
|
||||
"""
|
||||
Perform rotary embeddings on the queries and keys before copying into a blocked KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
|
||||
ragged_batch: Wrapper for the ragged batch.
|
||||
inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, head_size // 2]
|
||||
"""
|
||||
|
||||
q = qkv[:, :self.head_size * self.n_q_heads]
|
||||
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
|
||||
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
|
||||
|
||||
self.kernel(kv_cache, q, k, v, inverse_freqs, ragged_batch.batch_metadata_buffer(),
|
||||
ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from ....inference_utils import DtypeEnum
|
||||
from ....ragged import RaggedBatchWrapper
|
||||
from deepspeed.ops.op_builder import RaggedOpsBuilder
|
||||
from ... import DSKernelBase
|
||||
|
||||
|
||||
class LinearBlockedKVCopy(DSKernelBase):
|
||||
"""
|
||||
CUDA kernel implementation that copies the key and value activations for the current tokens
into a blocked KV cache without applying rotary position embeddings.
|
||||
"""
|
||||
|
||||
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
|
||||
supported_head_sizes = [64, 128]
|
||||
supported_q_ratios = [1, 2, 4, 5, 8]
|
||||
|
||||
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
|
||||
"""
|
||||
Args:
|
||||
head_size: The size of the attention head.
|
||||
dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
|
||||
"""
|
||||
|
||||
q_ratio = n_q_heads // n_kv_heads
|
||||
|
||||
if head_size not in LinearBlockedKVCopy.supported_head_sizes:
|
||||
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
|
||||
head_size, LinearBlockedKVCopy.supported_head_sizes))
|
||||
|
||||
if q_ratio not in LinearBlockedKVCopy.supported_q_ratios:
|
||||
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
|
||||
q_ratio, LinearBlockedKVCopy.supported_q_ratios))
|
||||
|
||||
if not isinstance(dtype, DtypeEnum):
|
||||
dtype = DtypeEnum(dtype)
|
||||
|
||||
if dtype not in LinearBlockedKVCopy.supported_dtypes:
|
||||
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
|
||||
dtype, LinearBlockedKVCopy.supported_dtypes))
|
||||
|
||||
inf_module = RaggedOpsBuilder().load()
|
||||
self.kernel = inf_module.linear_kv_copy
|
||||
self.head_size = head_size
|
||||
self.n_q_heads = n_q_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
|
||||
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
|
||||
"""
|
||||
Perform rotary embeddings on the queries and keys before copying into a blocked KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
|
||||
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
|
||||
ragged_batch: Wrapper for the ragged batch.
|
||||
"""
|
||||
|
||||
q = qkv[:, :self.head_size * self.n_q_heads]
|
||||
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
|
||||
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
|
||||
|
||||
self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(),
|
||||
ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
|
|
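To clarify what this wrapper fuses, here is a plain-PyTorch reference of the same copy with a flat (non-blocked) per-sequence cache; the helper name and layout are illustrative assumptions, not part of the DeepSpeed API. The real kernel instead scatters K/V into fixed-size cache blocks resolved through the ragged batch's KV block pointers.

import torch
from typing import List, Tuple

def reference_kv_copy(qkv: torch.Tensor, seq_lens: List[int], head_size: int,
                      n_q_heads: int, n_kv_heads: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Split the fused QKV projection and collect per-sequence K/V (no rotary applied)."""
    q_dim = head_size * n_q_heads
    kv_dim = head_size * n_kv_heads
    k = qkv[:, q_dim:q_dim + kv_dim].reshape(-1, n_kv_heads, head_size)
    v = qkv[:, q_dim + kv_dim:].reshape(-1, n_kv_heads, head_size)

    caches, start = [], 0
    for n_tokens in seq_lens:    # sequences are laid out back-to-back in the ragged batch
        caches.append((k[start:start + n_tokens], v[start:start + n_tokens]))
        start += n_tokens
    return caches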
@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .logits_gather import *
@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "logits_gather.h"

#define DISPATCH_TO_LOGITS_GATHER(T_TYPE, C_TYPE)                  \
    if (all_acts.options().dtype() == torch::T_TYPE) {             \
        launch_logits_gather((C_TYPE*)final_token_acts.data_ptr(), \
                             (const C_TYPE*)all_acts.data_ptr(),   \
                             batch_metadata_raw,                   \
                             seq_metadata_raw,                     \
                             n_seqs,                               \
                             embed_dim,                            \
                             at::cuda::getCurrentCUDAStream());    \
    }

/*
Logits gather will parse the ragged batch data structure and gather only the logits that
will be used for token sampling.
*/
void gather_for_logits(torch::Tensor& final_token_acts,
                       torch::Tensor& all_acts,
                       torch::Tensor& batch_metadata,
                       torch::Tensor& seq_metadata)
{
    const RaggedBatchDescriptor* batch_metadata_raw =
        reinterpret_cast<const RaggedBatchDescriptor*>(batch_metadata.data_ptr());

    const InflightSeqDescriptor* seq_metadata_raw =
        reinterpret_cast<const InflightSeqDescriptor*>(seq_metadata.data_ptr());

    const int n_seqs = final_token_acts.size(0);
    const int embed_dim = final_token_acts.size(1);

    TORCH_CHECK(all_acts.scalar_type() == final_token_acts.scalar_type(),
                "all_acts and final_token_acts must have the same scalar type");

    DISPATCH_TO_LOGITS_GATHER(kFloat, float)
    DISPATCH_TO_LOGITS_GATHER(kHalf, half)

#ifdef BF16_AVAILABLE
    DISPATCH_TO_LOGITS_GATHER(kBFloat16, __nv_bfloat16)
#endif
}
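The comment above describes the operation at a high level; as a hedged reference (assuming `seq_starts`/`seq_lens` carry the same start-index and token-count information as InflightSeqDescriptor), the whole binding reduces to a last-token index select:

import torch

def reference_logits_gather(all_acts: torch.Tensor, seq_starts: torch.Tensor,
                            seq_lens: torch.Tensor) -> torch.Tensor:
    """Pick the hidden state of the final token of each in-flight sequence."""
    last_token_idx = seq_starts + seq_lens - 1    # mirrors seq.start_idx + seq.n_tokens - 1
    return all_acts[last_token_idx]               # shape: [n_seqs, embed_dim]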
@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "ds_kernel_utils.h"
#include "logits_gather.cuh"
#include "memory_access_utils.h"
#include "ragged_dtypes.h"

namespace logits_gather {

constexpr int granularity = 16;
constexpr int threads = 512;

}  // namespace logits_gather

template <typename T>
__global__ void logits_gather_kernel(T* final_token_acts,
                                     const T* token_acts,
                                     const RaggedBatchDescriptor* ragged_batch,
                                     const InflightSeqDescriptor* inflight_batch,
                                     const int32_t embed_dim)
{
    constexpr int T_vector = logits_gather::granularity / sizeof(T);

    const int32_t seq_id = blockIdx.y;

    // It's possible we've padded the output Tensor (under CG conditions)
    if (seq_id >= ragged_batch->n_sequences) return;

    const InflightSeqDescriptor seq = inflight_batch[seq_id];
    const int final_token_idx = seq.start_idx + seq.n_tokens - 1;

    const int token_offset = final_token_idx * embed_dim;
    const int thread_offset =
        threadIdx.x * T_vector + blockIdx.x * logits_gather::threads * T_vector;

    const int final_token_offset = seq_id * embed_dim;

    T reg_buf[T_vector];

    if (thread_offset < embed_dim) {
        mem_access::load_global<logits_gather::granularity>(
            reg_buf, token_acts + token_offset + thread_offset);

        mem_access::store_global<logits_gather::granularity>(
            final_token_acts + final_token_offset + thread_offset, reg_buf);
    }
}

template <typename T>
void launch_logits_gather(T* final_token_acts,
                          const T* all_acts,
                          const RaggedBatchDescriptor* ragged_batch,
                          const InflightSeqDescriptor* inflight_batch,
                          const int32_t n_seqs,
                          const int32_t embed_dim,
                          cudaStream_t stream)
{
    constexpr int T_vector = logits_gather::granularity / sizeof(T);
    constexpr int elems_per_block = logits_gather::threads * T_vector;
    const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block;

    const dim3 grid(parallel_blocks, n_seqs, 1);
    const dim3 block(logits_gather::threads, 1, 1);

    logits_gather_kernel<T><<<grid, block, 0, stream>>>(
        final_token_acts, all_acts, ragged_batch, inflight_batch, embed_dim);
}

#define INSTANTIATE_FOR_TYPE(T)                                                         \
    template void launch_logits_gather<T>(T * final_token_acts,                        \
                                          const T* all_acts,                           \
                                          const RaggedBatchDescriptor* ragged_batch,   \
                                          const InflightSeqDescriptor* inflight_batch, \
                                          const int32_t n_seqs,                        \
                                          const int32_t embed_dim,                     \
                                          cudaStream_t stream);

INSTANTIATE_FOR_TYPE(float)
INSTANTIATE_FOR_TYPE(__half)

#ifdef BF16_AVAILABLE
INSTANTIATE_FOR_TYPE(__nv_bfloat16)
#endif
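To make the launch geometry concrete, a small worked example using the constants above; `embed_dim = 4096`, fp16 activations, and three in-flight sequences are assumed values.

granularity, threads = 16, 512                  # logits_gather::granularity / ::threads
elem_bytes = 2                                  # sizeof(__half)
T_vector = granularity // elem_bytes            # 8 elements per 16-byte vector access
elems_per_block = threads * T_vector            # 4096 elements covered by one block
embed_dim, n_seqs = 4096, 3                     # assumed model width and in-flight sequences
parallel_blocks = (embed_dim + elems_per_block - 1) // elems_per_block   # -> 1
grid, block = (parallel_blocks, n_seqs, 1), (threads, 1, 1)              # (1, 3, 1), (512, 1, 1)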
@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "ds_kernel_utils.h"
#include "ragged_dtypes.h"

#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif

template <typename T>
void launch_logits_gather(T* final_token_acts,
                          const T* all_acts,
                          const RaggedBatchDescriptor* batch_metadata,
                          const InflightSeqDescriptor* seq_metadata,
                          const int32_t n_seqs,
                          const int32_t embed_dim,
                          cudaStream_t stream);
@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "logits_gather.cuh"
#include "ragged_dtypes.h"

/*
Logits gather will parse the ragged batch data structure and gather only the logits that
will be used for token sampling.
*/
void gather_for_logits(torch::Tensor& final_token_acts,
                       torch::Tensor& all_acts,
                       torch::Tensor& batch_metadata,
                       torch::Tensor& seq_metadata);
@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from ... import DSKernelBase
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....inference_utils import elem_size
from ....ragged import RaggedBatchWrapper


class RaggedLogitsGather(DSKernelBase):
    """
    CUDA kernel implementation for gathering the hidden states of the final token
    of each sequence. This is used to reduce the cost of performing the unembedding.
    """

    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    def __init__(self, model_dim: int, fp_dtype: torch.dtype):
        """
        Parameters:
            model_dim (int): Embedding dimension of the hidden states.
            fp_dtype (torch.dtype): Data type for the input/output. Supported values
                are torch.float16, torch.bfloat16, and torch.float32.
        """
        if fp_dtype not in RaggedLogitsGather.supported_dtypes:
            raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
                fp_dtype, RaggedLogitsGather.supported_dtypes))

        if elem_size(fp_dtype) * model_dim % 16 != 0:
            raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(model_dim))

        inf_module = RaggedOpsBuilder().load()
        self.kernel = inf_module.gather_for_logits

    def __call__(self, final_token_activations: torch.Tensor, all_activations: torch.Tensor,
                 ragged_wrapper: RaggedBatchWrapper) -> torch.Tensor:
        """
        Gather the hidden states of the final token of each sequence from `all_activations` into
        `final_token_activations`.

        Args:
            final_token_activations (torch.Tensor): Output tensor of shape [num_seqs, model_dim]
            all_activations (torch.Tensor): Input tensor of shape [num_tokens, model_dim]
            ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch.
        """

        self.kernel(final_token_activations, all_activations, ragged_wrapper.batch_metadata_buffer(),
                    ragged_wrapper.inflight_seq_descriptors())
        return final_token_activations
Some files were not shown because too many files changed in this diff.