Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Connor Holmes 2023-11-03 15:07:35 -07:00 committed by GitHub
Parent 737ef296cd
Commit 38b41dffa1
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
263 changed files: 19167 additions and 33 deletions

56
.github/workflows/nv-a6000.yml vendored Normal file
View file

@ -0,0 +1,56 @@
name: nv-a6000
on:
pull_request:
paths-ignore:
- 'docs/**'
- 'blogs/**'
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
issues: write
jobs:
unit-tests:
runs-on: [self-hosted, nvidia, a6000]
container:
image: nvcr.io/nvidia/pytorch:23.03-py3
ports:
- 80
options: --gpus all --shm-size "8G"
steps:
- uses: actions/checkout@v3
- name: Check container state
run: |
ldd --version
nvcc --version
nvidia-smi
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
git rev-parse --short HEAD
python -m pip install .
- name: Install deepspeed
run: |
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
python -m pip install .[dev,1bit,autotuning]
ds_report
- name: Python environment
run: |
python -m pip list
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12"
python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12"

2
.github/workflows/nv-pre-compile-ops.yml vendored
View file

@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report

0
.gitmodules vendored Normal file
View file

View file

@ -49,6 +49,7 @@ repos:
entry: ./scripts/check-license.py
language: python
files: \.(py|c|cpp|cu|cc|h|hpp|cuh|hip|tr)$
exclude: ^(deepspeed/inference/v2/kernels/ragged_ops/blocked_flash|deepspeed/inference/v2/kernels/cutlass_ops/grouped_gemm)
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0

View file

@ -1,4 +1,6 @@
include *.txt README.md
include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc

View file

@ -153,9 +153,26 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
def total_memory(self, device_index=None):
return torch.cuda.get_device_properties(device_index).total_memory
def _get_nvml_gpu_id(self, torch_gpu_id):
"""
credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020
Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.
If the latter isn't set, return the same id.
"""
# if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var
if "CUDA_VISIBLE_DEVICES" in os.environ:
ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
return ids[torch_gpu_id] # remap
else:
return torch_gpu_id
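As a concrete illustration of the remap above (the environment value is assumed for the example, not taken from this diff): with CUDA_VISIBLE_DEVICES="2,3", torch exposes two devices, and torch device 0 must be queried through NVML as physical device 2.

# Minimal sketch of the remap logic (CUDA_VISIBLE_DEVICES value is assumed).
visible = "2,3"                        # e.g. os.environ["CUDA_VISIBLE_DEVICES"]
ids = list(map(int, visible.split(",")))
assert ids[0] == 2 and ids[1] == 3     # torch id 0 -> nvml id 2, torch id 1 -> nvml id 3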
def available_memory(self, device_index=None):
if pynvml:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
if device_index is None:
device_index = self.current_device()
handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.free
else:

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
enum ActivationType {
GELU = 0,
RELU = 1,
SILU = 2,
GEGLU = 3,
ReGLU = 4,
SiGLU = 5,
IDENTITY = 6,
InvalidType = -1
};

View file

@ -11,6 +11,11 @@ used throughout the codebase.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#define DS_HD_INLINE __host__ __device__ __forceinline__
#define DS_D_INLINE __device__ __forceinline__

View file

@ -233,6 +233,60 @@ DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
#endif
}
template <>
DS_D_INLINE int32_t element<ROpType::Add>(const int32_t lhs, const int32_t rhs)
{
return lhs + rhs;
}
template <>
DS_D_INLINE int32_t element<ROpType::Max>(const int32_t lhs, const int32_t rhs)
{
return (lhs > rhs) ? lhs : rhs;
}
template <>
DS_D_INLINE int32_t element<ROpType::Min>(const int32_t lhs, const int32_t rhs)
{
return (lhs < rhs) ? lhs : rhs;
}
template <>
DS_D_INLINE uint32_t element<ROpType::Add>(const uint32_t lhs, const uint32_t rhs)
{
return lhs + rhs;
}
template <>
DS_D_INLINE uint32_t element<ROpType::Max>(const uint32_t lhs, const uint32_t rhs)
{
return (lhs > rhs) ? lhs : rhs;
}
template <>
DS_D_INLINE uint32_t element<ROpType::Min>(const uint32_t lhs, const uint32_t rhs)
{
return (lhs < rhs) ? lhs : rhs;
}
template <>
DS_D_INLINE int64_t element<ROpType::Add>(const int64_t lhs, const int64_t rhs)
{
return lhs + rhs;
}
template <>
DS_D_INLINE int64_t element<ROpType::Max>(const int64_t lhs, const int64_t rhs)
{
return (lhs > rhs) ? lhs : rhs;
}
template <>
DS_D_INLINE int64_t element<ROpType::Min>(const int64_t lhs, const int64_t rhs)
{
return (lhs < rhs) ? lhs : rhs;
}
/*
Reduction initialization primitives
*/
@ -310,6 +364,78 @@ DS_D_INLINE __half2 init<ROpType::Max>()
#endif
}
template <>
DS_D_INLINE int32_t init<ROpType::Add>()
{
return 0;
}
template <>
DS_D_INLINE int32_t init<ROpType::Min>()
{
return 0x7FFFFFFF;
}
template <>
DS_D_INLINE int32_t init<ROpType::Max>()
{
return 0x80000000;
}
template <>
DS_D_INLINE uint32_t init<ROpType::Add>()
{
return 0;
}
template <>
DS_D_INLINE uint32_t init<ROpType::Min>()
{
return 0xFFFFFFFF;
}
template <>
DS_D_INLINE uint32_t init<ROpType::Max>()
{
return 0;
}
template <>
DS_D_INLINE int64_t init<ROpType::Add>()
{
return 0;
}
template <>
DS_D_INLINE int64_t init<ROpType::Min>()
{
return 0x7FFFFFFFFFFFFFFF;
}
template <>
DS_D_INLINE int64_t init<ROpType::Max>()
{
return 0x8000000000000000;
}
template <>
DS_D_INLINE uint64_t init<ROpType::Add>()
{
return 0;
}
template <>
DS_D_INLINE uint64_t init<ROpType::Min>()
{
return 0xFFFFFFFFFFFFFFFF;
}
template <>
DS_D_INLINE uint64_t init<ROpType::Max>()
{
return 0;
}
template <ROpType Op, typename T>
DS_D_INLINE void init(T* data)
{
@ -352,8 +478,8 @@ here (fold is C++17 only and I don't think helps and recursion feels like
huge overkill that harms readability) that would be wonderful.
*/
template <ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
template <typename T, ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@ -361,8 +487,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
}
}
template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
template <typename T, ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@ -371,8 +497,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
}
}
template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
template <typename T, ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@ -382,8 +508,13 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
}
}
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
template <typename T,
ROpType Op1,
ROpType Op2,
ROpType Op3,
ROpType Op4,
int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@ -403,16 +534,15 @@ the number of warps in the block (which may exceed that
if the block is partitioned or if we do a conservative bound at
compile time).
*/
template <int total_warps, ROpType... Ops>
template <typename T, int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp_arg,
float* data)
T* data)
{
constexpr int elems = sizeof...(Ops);
// Separated for now in case this no longer is true
constexpr int bytes = sizeof(float);
constexpr int bytes = sizeof(T);
// Unused when `partition_size == 1` or total_warps == 1
__shared__ float reduce_buffer[max_warps * elems];
__shared__ T reduce_buffer[max_warps * elems];
#ifdef __HIP_PLATFORM_AMD__
const int total_threads = blockDim.x * blockDim.y * blockDim.z;
@ -422,7 +552,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
#endif
// Always perform warp-scope reduction
_warp<Ops...>(warp_arg, data);
_warp<T, Ops...>(warp_arg, data);
// If max_warps == 1 let's skip the runtime check
if (total_warps != 1) {
@ -447,7 +577,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
init<Ops...>(data);
}
_warp<Ops..., total_warps>(warp_arg, data);
_warp<T, Ops..., total_warps>(warp_arg, data);
#pragma unroll
for (int i = 0; i < elems; i++) {
@ -476,7 +606,7 @@ us to obfuscate the details of the partitioned implementation.
template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
_block<warp_bound, Op>(tb, warp, &val);
_block<float, warp_bound, Op>(tb, warp, &val);
}
template <ROpType Op1, ROpType Op2, int warp_bound>
@ -486,7 +616,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val2)
{
float data[2] = {val1, val2};
_block<warp_bound, Op1, Op2>(tb, warp, data);
_block<float, warp_bound, Op1, Op2>(tb, warp, data);
val1 = data[0];
val2 = data[1];
}
@ -499,7 +629,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val3)
{
float data[3] = {val1, val2, val3};
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
_block<float, warp_bound, Op1, Op2, Op3>(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
@ -514,7 +644,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val4)
{
float data[4] = {val1, val2, val3, val4};
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
_block<float, warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
@ -531,10 +661,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float& val)
{
if (num_threads <= hw_warp_size) {
_warp<Op, num_threads>(warp, &val);
_warp<float, Op, num_threads>(warp, &val);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
_block<num_warps, Op>(tb, warp, &val);
_block<float, num_warps, Op>(tb, warp, &val);
}
}
@ -547,10 +677,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[2] = {val1, val2};
if (num_threads <= hw_warp_size) {
_warp<Op1, Op2, num_threads>(warp, data);
_warp<float, Op1, Op2, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
_block<num_warps, Op1, Op2>(tb, warp, data);
_block<float, num_warps, Op1, Op2>(tb, warp, data);
}
val1 = data[0];
@ -567,10 +697,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[3] = {val1, val2, val3};
if (num_threads <= hw_warp_size) {
_warp<Op1, Op2, Op3, num_threads>(warp, data);
_warp<float, Op1, Op2, Op3, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
_block<num_warps, Op1, Op2, Op3>(tb, warp, data);
_block<float, num_warps, Op1, Op2, Op3>(tb, warp, data);
}
val1 = data[0];
@ -589,10 +719,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[4] = {val1, val2, val3, val4};
if (num_threads <= hw_warp_size) {
_warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
_warp<float, Op1, Op2, Op3, Op4, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
_block<float, num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
}
val1 = data[0];
@ -601,4 +731,48 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
val4 = data[3];
}
/*
Arg-reduce is a specialization of the above. We only support this with a single reduction
parameter. This only works for max/min reductions.
*/
__align__(8) struct IdxReduceResult {
/*
NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits
and the val is the most significant. Changing the order of this declaration
will break the code.
*/
int idx;
float val;
};
template <ROpType Op, int warpBound>
DS_D_INLINE IdxReduceResult
idx_reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float val, int idx)
{
IdxReduceResult res = {idx, val};
// Clear out the nan. This shouldn't be an issue for our initial applications
if (isnan(val)) res.val = init<Op>();
// Can do float compares as integers. By packing the index into the lower bits
// we can just do a single int64 rather than a branch, compare, and select.
// One side benefit of this is that it is by nature a stable algorithm and
// will always bias ties to the higher index.
int64_t* res_as_int = reinterpret_cast<int64_t*>(&res);
// The way floating point compare works is normally to perform a sign comparison
// and if they match, then do a comparison of the rest of the bits as unsigned
// integers. Since we are bundling these, that means for negative values we need
// to reverse the sort order, which we can do with an XOR.
if (val < 0) { *res_as_int ^= 0x7fffffff00000000; }
_block<int64_t, warpBound, Op>(tb, warp, res_as_int);
// Sign bit is preserved, so we can check if we need to invert the mantissa back
if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; }
return res;
}
} // namespace reduce
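The packing trick used by idx_reduce can be checked outside the kernel. Below is a minimal Python sketch (illustrative only, assuming IEEE-754 single precision) that builds the same int64 key — float bits in the high word, index in the low word, non-sign bits XOR-flipped for negatives — and shows that signed integer ordering of the keys matches float ordering, with ties biased toward the higher index:

import struct

def idx_reduce_key(val: float, idx: int) -> int:
    bits = struct.unpack("<I", struct.pack("<f", val))[0]   # raw float32 bits
    packed = (bits << 32) | (idx & 0xFFFFFFFF)              # value in high word, index in low word
    if val < 0:
        packed ^= 0x7FFFFFFF00000000                        # reverse the sort order for negatives
    # Reinterpret as signed 64-bit so the preserved sign bit orders negatives below positives.
    return packed - (1 << 64) if packed & (1 << 63) else packed

pairs = [(-2.0, 0), (-0.5, 1), (0.25, 2), (3.0, 3), (3.0, 4)]
best = max(pairs, key=lambda p: idx_reduce_key(*p))
assert best == (3.0, 4)   # max value wins; the tie resolves toward the higher index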

View file

@ -18,8 +18,8 @@ using __nv_bfloat162 = __half2;
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
constexpr float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

View file

@ -2,3 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .v2.engine_v2 import InferenceEngineV2
from .v2 import build_hf_engine

View file

@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .engine_v2 import InferenceEngineV2
from .engine_factory import build_hf_engine

View file

@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from functools import reduce
from typing import Iterable
import torch
from deepspeed.accelerator import get_accelerator
def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor:
shape_size = reduce(lambda x, y: x * y, shape)
if shape_size == 0:
raise ValueError("Cannot create empty tensor with size 0")
return tensor.flatten()[:shape_size].view(shape)
def on_device(method) -> torch.Tensor:
"""
Wraps a method to ensure the returned tensor is on the current device.
"""
def wrapped(self, *args, **kwargs):
tensor = method(self, *args, **kwargs)
if isinstance(tensor, torch.Tensor):
return tensor.to(get_accelerator().current_device()).contiguous()
return tensor
return wrapped
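A short illustrative use of the helpers above (buffer size and shape are arbitrary): empty_from carves a reshaped view out of a preallocated buffer instead of allocating new memory, which is how downstream modules reuse scratch space.

import torch

scratch = torch.empty(1024)                      # preallocated workspace
view = empty_from(scratch, (4, 8))               # no new allocation, just a view
assert view.shape == (4, 8)
assert view.data_ptr() == scratch.data_ptr()     # shares storage with the buffer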

View file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .base_engine import CheckpointEngineBase
from .in_memory_engine import InMemoryModelEngine
from .huggingface_engine import HuggingFaceCheckpointEngine

View file

@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from abc import ABC, abstractmethod
from typing import Iterable, Tuple
import torch
#from .huggingface_engine import HuggingFaceCheckpointEngine
MEGATRON = 'megatron'
HUGGINGFACE = 'huggingface'
class CheckpointEngineBase(ABC):
"""
Abstract interface for checkpoint engines to implement.
There is no ``__init__`` method here by design, since the creation of the checkpoint
engine will happen outside the policy/engine code. The tradeoff being made here is
that we will write different frontends for different checkpoint engines, but these
frontends can be tailored to the specific checkpoint engine/model source needs.
"""
@abstractmethod
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
"""
This method should create a generator of tuples of the form (name, parameter) for
all parameters in the model. The name should be the fully qualified name of the
parameter, and the parameter should be a torch.Tensor.
The expected use of a checkpoint engine is the following:
```python
for name, parameter in checkpoint_engine.parameters():
container_map.map_param(name, parameter)
```
For a concrete use example, see ``InferenceV2Policy``.
"""
...
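To sketch the intended contract, a hypothetical engine (not one shipped in this change) only needs to yield (name, tensor) pairs:

class DictCheckpointEngine(CheckpointEngineBase):
    """Hypothetical engine that serves parameters from an in-memory state_dict."""

    def __init__(self, state_dict):
        super().__init__()
        self._sd = state_dict

    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        for name, param in self._sd.items():
            yield name, param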

View file

@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import json
import torch
from .base_engine import CheckpointEngineBase
from typing import Iterable, Tuple
from ..logging import inference_logger
class HuggingFaceCheckpointEngine(CheckpointEngineBase):
def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
super().__init__()
from transformers import AutoConfig, GenerationConfig
self.model_name_or_path = model_name_or_path
self.auth_token = auth_token
self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
# Define this property here so we can use it in the model implementation
if not hasattr(self.model_config, "max_seq_length"):
self.model_config.max_seq_length = self.model_config.max_position_embeddings
else:
self.model_config.max_seq_length = self.generation_config.max_length
self._all_ckpt_paths = self._fetch_checkpoint_files()
def _fetch_checkpoint_files(self):
"""
Fetch the checkpoint files from the HuggingFace Hub.
"""
# TODO(jeff): for models like llama-2 the user will have to provide an auth `token`,
# currently coming from the ckpt engine init but maybe a catch all kwargs for other
# snapshot download parameters would be more flexible.
# NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
# checkpoint files that may be present. Example of all files in the llama-2-7b
# repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
from huggingface_hub import snapshot_download
if os.path.isdir(self.model_name_or_path):
self._local_checkpoint_dir = self.model_name_or_path
else:
self._local_checkpoint_dir = snapshot_download(self.model_name_or_path,
allow_patterns=[
"*.bin",
"*.json",
"*.pt",
],
revision=None,
token=self.auth_token)
assert os.path.isdir(
self._local_checkpoint_dir
), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint."
model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json")
if not os.path.isfile(model_param_json):
# We don't need any json as all such HF models will have pytorch_model.bin
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')]
else:
param_map = json.load(open(model_param_json, "r"))
# weight_map -> { "lm_head.weight": "pytorch_model-00002-of-00002.bin", ... }
weight_map = param_map["weight_map"]
# unique set of all checkpoint files
all_checkpoint_files = set(weight_map.values())
# get absolute path of all unique checkpoint files
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, f) for f in all_checkpoint_files]
return all_checkpoint_files
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Generator of model parameters (satisfies the CheckpointEngineBase interface).
"""
for checkpoint in self._all_ckpt_paths:
inference_logger().info(f"Loading checkpoint: {checkpoint}")
checkpoint_sd = torch.load(checkpoint, map_location='cpu')
param_keys = list(checkpoint_sd.keys())
for param_name in param_keys:
param = checkpoint_sd[param_name]
yield param_name, param
if __name__ == "__main__":
# To test, add your auth_token here and run `python huggingface_engine.py`
engine = HuggingFaceCheckpointEngine(model_name_or_path="meta-llama/Llama-2-7b-hf",
auth_token="hf_xxxxxxxxxxxxxxxxx")
for name, param in engine.parameters():
print(name, param.shape)

View file

@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Iterable, Tuple
import torch
from .base_engine import CheckpointEngineBase
class InMemoryModelEngine(CheckpointEngineBase):
"""
This "checkpoint" engine uses the existing interface to enable loading parameters into an
inference model from a model already instantiated in memory. In general, this is not the
recommended way to use the inference engine, and should only be used when absolutely necessary.
The primary limitation of this approach is that the model must be fully instantiated in memory.
In a tensor parallel scenario, this means that the model is either replicated many times in host
memory. Currently, it is also recommended to only use this approach for models held in host memory.
In order to free the memory held by this copy of the model, we delete the model in the first call
to `parameters`, so it is not safe to make this call twice.
"""
def __init__(self, model: torch.nn.Module) -> None:
"""
Create virtual checkpoint engine for the provided module.
Args:
model (torch.nn.Module): Model to load parameters from.
"""
super().__init__()
self.model = model
def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
for name, parameter in self.model.named_parameters():
yield name, parameter
del self.model
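Illustrative usage (the module choice is arbitrary); note that, per the warning above, the parameter generator can only be consumed once because the model is deleted inside it:

import torch

model = torch.nn.Linear(16, 16)
engine = InMemoryModelEngine(model)
names = [name for name, _ in engine.parameters()]   # first and only pass
assert "weight" in names and "bias" in names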

View file

@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.pydantic_v1 import Field
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from .ragged import DSStateManagerConfig
class DeepSpeedTPConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """
tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """
class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
""" Sets parameters for DeepSpeed Inference Engine. """
tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several
GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
"""
state_manager: DSStateManagerConfig = Field({}, alias="manager")
"""
Configuration for managing persistent state
"""

View file

@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import logging
from typing import Any
from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
def build_hf_engine(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO,
random_weights_config: Any = None,
fill_random: bool = False) -> InferenceEngineV2:
"""
Build an InferenceV2 engine for HuggingFace models.
"""
# Set up logging
inference_logger(level=debug_level)
# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)
# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config
# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
from .model_implementations.opt.policy import OPTPolicy
policy = OPTPolicy(checkpoint_engine, model_config)
elif model_config.model_type == "llama":
from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
policy = Llama2Policy(checkpoint_engine, model_config)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
policy = MistralPolicy(checkpoint_engine, model_config)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
return InferenceEngineV2(policy, engine_config)
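Illustrative call (the model path and settings are placeholders): the public entry point is re-exported from deepspeed.inference.v2, as the __init__ changes above show.

from deepspeed.inference.v2 import RaggedInferenceEngineConfig, build_hf_engine

engine = build_hf_engine("facebook/opt-125m", RaggedInferenceEngineConfig())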

View file

@ -0,0 +1,217 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from typing import Iterable, Tuple
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.comm.comm import init_distributed
from .model_implementations import InferenceV2Policy
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .config_v2 import RaggedInferenceEngineConfig
INFERENCE_MODEL_TIMER = "model-forward-inference"
class InferenceEngineV2:
_config: RaggedInferenceEngineConfig
"""
Configuration of the inference engine.
"""
#_model: DSInferenceModelBase
"""
Inference model supporting ragged inference.
"""
_state_manager: DSStateManager
"""
Persistent state manager for sequences and KV-cache.
"""
@property
def free_blocks(self) -> int:
"""
Number of free KV blocks.
"""
return self._state_manager.free_blocks
def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
"""
Create the Inference V2 engine.
Arguments:
policy (InferenceV2Policy): Policy for the model implementation. This policy object
will be used to build the model and load the checkpoint associated with it.
engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine.
"""
self._config = engine_config
self._policy = policy
self._base_mp_group = self._initialize_tp_group()
# Build model from policy
inference_logger().info("Building model...")
self._model = self._policy.build_model(self._config, self._base_mp_group)
inference_logger().info("Model built.")
# Create state manager
self._batch = RaggedBatchWrapper(self._config.state_manager)
self._state_manager = DSStateManager(self._config.state_manager,
self._model.kv_cache_config(),
base_mp_group=self._base_mp_group)
self._model.set_state_manager(self._state_manager)
def _initialize_tp_group(self):
"""
Implementation of our TP group initialization.
"""
init_distributed()
local_rank = int(os.getenv("LOCAL_RANK", 0))
get_accelerator().set_device(local_rank)
if local_rank >= self._config.tensor_parallel.tp_size:
raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.")
ranks = list(range(self._config.tensor_parallel.tp_size))
return dist.new_group(ranks=ranks)
def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -> torch.Tensor:
"""
Put a ragged batch onto the inference engine. This will perform one forward and return
a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens
are not calculated.
Arguments:
batch_uids: Iterable of uids for the batch on the host
batch_tokens: Iterable of token tensors for the batch on the host
"""
token_lens = [len(tokens) for tokens in batch_tokens]
schedule_check = self.can_schedule(batch_uids, token_lens)
if schedule_check != SchedulingResult.Success:
raise SchedulingError(schedule_check)
self._batch.clear()
for uid, tokens in zip(batch_uids, batch_tokens):
host_seq_desc = self._state_manager.get_or_create_sequence(uid)
self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
host_seq_desc.pre_forward(tokens.numel())
# We can disable checks since we already validated schedulability.
self._batch.insert_sequence(host_seq_desc, tokens, do_checks=False)
# Send all metadata to the device
self._batch.finalize()
# Prep all data structures for the actual forward (in anticipation of CG in the future)
# and also to amortize some of the costs in a more straightforward way.
self._model.prepare_batch(self._batch)
# Model implementation will pick up in the forward.
logits = self._model.forward(self._batch)
# We return one set of logits per sequence in the batch (saves cost on unembedding)
assert logits.shape[0] == self._batch.current_sequences
for uid in batch_uids:
host_seq_desc = self._state_manager.get_sequence(uid)
host_seq_desc.post_forward() # Updates sequence metadata.
self._model.maybe_free_kv(host_seq_desc)
return logits
def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, int]:
"""
Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
(this UID may not be recognized by the model yet), this will return the number of tokens
and blocks to reserve for the request.
Arguments:
uid (int): The UID of the sequence (as tracked by the scheduling entity). If
this is a new sequence (with a UID unknown to the inference engine), then
an empty placeholder is created to pass to the occupancy logic.
max_request_tokens (int): The maximum number of new tokens that would be sent for this request.
max_request_blocks (int): The maximum number of KV blocks available to allocate for this request.
Returns:
Tuple[int, int]: The number of tokens that can be scheduled for the request and the number
of KV blocks required to schedule them.
"""
seq_desc = self._state_manager.get_sequence(uid)
if seq_desc is None:
if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences):
return (0, 0)
seq_desc = PlaceholderSequenceDescriptor()
req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks)
return (req_tokens, req_blocks)
def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult:
"""
Dry run a batch to determine if it can be scheduled. Placeholder sequences will be
created for any UIDs that are unknown to the inference engine.
Arguments:
uids (Iterable[int]): Iterable of UIDs for the batch
lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. These lengths
correspond to the number of tokens to send in the hypothetical forward; history
tokens will be determined via UID lookup and future tokens are disregarded.
Returns:
SchedulingResult: ``SchedulingResult.Success`` if the batch can be scheduled, otherwise the
result naming the limit that would be exceeded.
"""
cur_seqs = self._state_manager.n_tracked_sequences
free_blocks = self._state_manager.free_blocks
req_blocks = 0
batch_len = 0
if len(uids) > self._config.state_manager.max_ragged_sequence_count:
# Can only compose a batch from a limited number of sequences
return SchedulingResult.BatchSequenceLimitExceeded
for uid, length in zip(uids, lengths):
seq_desc = self._state_manager.get_sequence(uid)
if seq_desc is None:
cur_seqs += 1
seq_desc = PlaceholderSequenceDescriptor()
sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks)
if sched_len != length:
# We ran out of KV cache
return SchedulingResult.KVCacheLimitExceeded
batch_len += length
free_blocks -= sched_blocks
if cur_seqs > self._config.state_manager.max_tracked_sequences:
# Would run out of tracking metadata
return SchedulingResult.EngineSequenceLimitExceeded
if batch_len > self._config.state_manager.max_ragged_batch_size:
# Would exceed the maximum batch size
return SchedulingResult.BatchTokenLimitExceeded
return SchedulingResult.Success
def flush(self, uid: int) -> None:
"""
Remove all state associated with a sequence from the inference engine.
Arguments:
uid (int): The UID of the sequence to flush.
"""
self._state_manager.flush_sequence(uid)
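Putting the methods above together, a minimal hypothetical serving step might look like the following; `engine` is assumed to be an InferenceEngineV2 built elsewhere (e.g. via build_hf_engine) and the token ids are placeholders.

import torch

from deepspeed.inference.v2.scheduling_utils import SchedulingResult

uids = [0, 1]
prompts = [torch.tensor([101, 2054, 2003]), torch.tensor([101, 7592])]

if engine.can_schedule(uids, [len(t) for t in prompts]) == SchedulingResult.Success:
    logits = engine.put(uids, prompts)        # one row of logits per sequence
    next_tokens = logits.argmax(dim=-1)       # greedy next-token choice
    engine.flush(uids[0])                     # drop all state for a finished sequence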

View file

@ -0,0 +1,105 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Dict
import torch
from enum import Enum, IntEnum
class NormTypeEnum(Enum):
LayerNorm: str = "layer_norm"
RMSNorm: str = "rms_norm"
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
int8 = torch.int8, "torch.int8", "int8"
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj
def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)
ELEM_SIZES: Dict[torch.dtype, int] = {
torch.float16: 2,
torch.bfloat16: 2,
torch.float32: 4,
torch.float64: 8,
torch.int8: 1,
torch.uint8: 1,
torch.int16: 2,
torch.int32: 4,
torch.int64: 8,
torch.bool: 1,
}
class ActivationType(IntEnum):
"""
Types of activations supported by DS-Inference
"""
GELU = 0
RELU = 1
SILU = 2
GEGLU = 3
ReGLU = 4
SiGLU = 5
IDENTITY = 6
InvalidType = -1
def is_gated(act_fn: ActivationType) -> bool:
"""
Return True if the given activation function is gated.
"""
if not isinstance(act_fn, ActivationType):
act_fn = ActivationType(act_fn)
return act_fn in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
def elem_size(dtype: torch.dtype) -> int:
"""
Return size in bytes of the given dtype.
"""
try:
return ELEM_SIZES[dtype]
except KeyError:
raise ValueError("Unknown dtype size for {}".format(dtype))
def ceil_div(a: int, b: int) -> int:
"""
Return ceil(a / b).
"""
return -(-a // b)
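A few quick checks of the helpers above (purely illustrative):

assert DtypeEnum("fp16") is DtypeEnum.fp16 is DtypeEnum(torch.float16)
assert elem_size(torch.bfloat16) == 2
assert is_gated(ActivationType.GEGLU) and not is_gated(ActivationType.GELU)
assert ceil_div(10, 3) == 4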

View file

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .ds_kernel import DSKernelBase

View file

@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .bias_activations import *
from .blas_kernels import *
from .cuda_layer_norm import *
from .cuda_rms_norm import *
from .gated_activations import *

View file

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .bias_activation import *

View file

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "bias_activation.h"
#include <c10/cuda/CUDAStream.h>
#include "ds_kernel_utils.h"
#ifdef BF16_AVAILABLE
#define DTYPE_SWITCH(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kBFloat16) { \
using scalar_t = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
} \
}()
#else
#define DTYPE_SWITCH(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
} \
}()
#endif
/*
In-place bias and activation fusion kernel.
*/
void bias_activation(torch::Tensor& activation,
c10::optional<torch::Tensor>& bias,
const int32_t act_type)
{
const ActivationType atype = static_cast<ActivationType>(act_type);
const int32_t rows = activation.size(0);
const int32_t cols = activation.size(1);
TORCH_CHECK(atype == ActivationType::GELU || atype == ActivationType::RELU ||
atype == ActivationType::SILU || atype == ActivationType::IDENTITY,
"Unsupported activation type for BiasActivation");
TORCH_CHECK(activation.dim() == 2, "BiasActivation only supports 2D activation tensors");
DTYPE_SWITCH(activation.scalar_type(), [&] {
scalar_t* activation_ptr = reinterpret_cast<scalar_t*>(activation.data_ptr());
const scalar_t* bias_ptr;
if (bias.has_value()) {
TORCH_CHECK(activation.scalar_type() == bias.value().scalar_type(),
"BiasActivation activation and bias must have same dtype");
bias_ptr = reinterpret_cast<const scalar_t*>(bias.value().data_ptr());
} else {
bias_ptr = nullptr;
}
if (atype == ActivationType::IDENTITY && bias_ptr == nullptr) { return; }
launch_bias_activation<scalar_t>(
activation_ptr, bias_ptr, rows, cols, atype, c10::cuda::getCurrentCUDAStream());
});
}

View file

@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <cassert>
#include "activation_type.h"
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
// Default activation function will error out
template <ActivationType ActType>
DS_D_INLINE float act_fn(float val);
template <>
DS_D_INLINE float act_fn<ActivationType::IDENTITY>(float val)
{
return val;
}
template <>
DS_D_INLINE float act_fn<ActivationType::RELU>(float val)
{
return val > 0.0f ? val : 0.0f;
}
template <>
DS_D_INLINE float act_fn<ActivationType::GELU>(float val)
{
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
constexpr float mul_param = 0.044715f;
return val * 0.5f * (1.0f + tanhf(sqrt_param * (val + mul_param * val * val * val)));
}
template <>
DS_D_INLINE float act_fn<ActivationType::SILU>(float val)
{
return val / (1.0f + expf(-val));
}
namespace bias_act {
constexpr int access_size = 16;
constexpr int threads = 512;
constexpr int unroll = 4;
} // namespace bias_act
template <typename T, ActivationType ActType>
__global__ void bias_activation_kernel(T* activation,
const T* bias,
const int32_t rows,
const int32_t cols)
{
constexpr int vector_T = bias_act::access_size / sizeof(T);
const int32_t thread_offset = threadIdx.x * vector_T;
const int32_t block_offset = blockIdx.x * vector_T * bias_act::unroll * bias_act::threads;
const int32_t base_offset = block_offset + thread_offset;
const int32_t thread_stride = bias_act::threads * vector_T;
#pragma unroll
for (int i = 0; i < bias_act::unroll; i++) {
const int32_t iter_offset = base_offset + i * thread_stride;
const int32_t row = iter_offset / cols;
T buffer[vector_T];
T bias_buffer[vector_T];
if (row < rows) {
const int32_t col = iter_offset % cols;
mem_access::load_global<bias_act::access_size>(buffer, activation + iter_offset);
mem_access::load_global<bias_act::access_size>(
bias_buffer, bias + col, bias != nullptr);
#pragma unroll
for (int j = 0; j < vector_T; j++) {
float val =
conversion::to<float>(buffer[j]) + conversion::to<float>(bias_buffer[j]);
buffer[j] = conversion::to<T>(act_fn<ActType>(val));
}
mem_access::store_global<bias_act::access_size>(activation + iter_offset, buffer);
}
}
}
#define ACT_TYPE_SWITCH(ACT_TYPE, ...) \
if (ACT_TYPE == ActivationType::IDENTITY) { \
constexpr ActivationType act_fn_t = ActivationType::IDENTITY; \
return __VA_ARGS__(); \
} else if (ACT_TYPE == ActivationType::RELU) { \
constexpr ActivationType act_fn_t = ActivationType::RELU; \
return __VA_ARGS__(); \
} else if (ACT_TYPE == ActivationType::GELU) { \
constexpr ActivationType act_fn_t = ActivationType::GELU; \
return __VA_ARGS__(); \
} else if (ACT_TYPE == ActivationType::SILU) { \
constexpr ActivationType act_fn_t = ActivationType::SILU; \
return __VA_ARGS__(); \
} else { \
assert(false); \
}
template <typename T>
void launch_bias_activation(T* activation,
const T* bias,
const int32_t n_rows,
const int32_t n_cols,
const ActivationType activation_type,
cudaStream_t stream)
{
constexpr int32_t elems_per_block =
bias_act::threads * bias_act::unroll * bias_act::access_size / sizeof(T);
const int32_t total_elems = n_rows * n_cols;
const int32_t blocks = (total_elems + elems_per_block - 1) / elems_per_block;
const dim3 grid(blocks);
const dim3 block(bias_act::threads);
ACT_TYPE_SWITCH(activation_type, [&] {
bias_activation_kernel<T, act_fn_t>
<<<grid, block, 0, stream>>>(activation, bias, n_rows, n_cols);
});
}
#define INSTANTIATE_FOR_T(T) \
template void launch_bias_activation<T>( \
T*, const T*, const int32_t, const int32_t, const ActivationType, cudaStream_t);
INSTANTIATE_FOR_T(__half);
#ifdef BF16_AVAILABLE
INSTANTIATE_FOR_T(__nv_bfloat16);
#endif

View file

@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "activation_type.h"
template <typename T>
void launch_bias_activation(T* activation,
const T* bias,
const int32_t n_rows,
const int32_t n_cols,
const ActivationType activation_type,
cudaStream_t stream);
void bias_activation(torch::Tensor& activation,
c10::optional<torch::Tensor>& bias,
const int32_t activation_type);

View file

@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Optional
import torch
from ....inference_utils import ActivationType, DtypeEnum
from deepspeed.ops.op_builder import InferenceCoreBuilder
from ... import DSKernelBase
class CUDABiasActivation(DSKernelBase):
"""
CUDA implementation of the bias-activation kernel. This kernel should be deprecated once
the bias activation is fused into the linear kernel in all scenarios.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_act_fns = [ActivationType.IDENTITY, ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]
def __init__(self, channels: int, dtype: DtypeEnum, act_fn: ActivationType) -> None:
"""
Compile and validate for the fused bias-activation kernel.
Parameters:
channels (int): Number of channels to expect in the activation.
dtype (torch.dtype): Data type for the input/output. Supported values
are DtypeEnum.fp16 and DtypeEnum.bf16.
act_fn (ActivationType): Activation function to use. Only IDENTITY, GELU, RELU, and SILU are supported.
"""
if channels % 8 != 0:
raise ValueError("channels must be divisible by 8")
if DtypeEnum(dtype) not in CUDABiasActivation.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
dtype, CUDABiasActivation.supported_dtypes))
act_fn = ActivationType(act_fn)
if act_fn not in CUDABiasActivation.supported_act_fns:
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
act_fn, CUDABiasActivation.supported_act_fns))
inf_module = InferenceCoreBuilder().load()
self.kernel = inf_module.bias_activation
self.act_fn = act_fn
def __call__(self, activation: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Add an optional bias and perform the non-linear activation function.
Parameters:
activation (torch.Tensor): Input tensor of shape [tokens, channels]
bias (torch.Tensor): Optional bias tensor of shape [channels]
Returns:
activation that has been updated in-place
"""
self.kernel(activation, bias, self.act_fn.value)
return activation
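Illustrative usage on a CUDA device (shapes are placeholders; channels must be a multiple of 8, as validated above):

import torch

bias_act = CUDABiasActivation(channels=4096, dtype=DtypeEnum.fp16, act_fn=ActivationType.GELU)
hidden = torch.randn(128, 4096, dtype=torch.float16, device="cuda")
bias = torch.randn(4096, dtype=torch.float16, device="cuda")
hidden = bias_act(hidden, bias)   # updated in place and returned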

View file

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .blas_linear import *

View file

@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdio>
#include "blas_utils.h"
#define DISPATCH_BLAS_MATMUL(T_TYPE, C_TYPE) \
if (output.options().dtype() == torch::T_TYPE) { \
blas_gemm_ex(output.data_ptr(), \
(const void*)weights.data_ptr(), \
(const void*)hidden_states.data_ptr(), \
m, \
n, \
k, \
lda, \
ldb, \
ldc, \
trans_a, \
trans_b, \
&alpha, \
&beta, \
C_TYPE); \
}
void blas_linear(at::Tensor& output, at::Tensor& hidden_states, at::Tensor& weights)
{
/*
Expected shape: output([total_tokens_across_dims], out_neurons)
hidden_states([total_tokens_across_dims], in_neurons)
weights(out_neurons, in_neurons)
We are going to assume contiguous for the above shapes.
The shapes are going to get messed with a little internally to handle column-major
GEMMs.
*/
// Number of tokens is N (since the GEMM output is column-major but our Tensor
// is row-major, we need to transpose the shapes)
const int n = output.numel() / output.size(-1);
const int k = weights.size(1);
const int m = weights.size(0);
// A strides
const bool trans_a = weights.stride(1) == 1;
const int lda = (trans_a) ? weights.stride(0) : weights.stride(1);
// B strides
const bool trans_b = hidden_states.stride(-1) != 1;
const int ldb = (trans_b) ? hidden_states.stride(-1) : hidden_states.stride(-2);
// C strides
const int ldc = output.stride(-2);
const float alpha = 1.0f;
const float beta = 0.0f;
TORCH_CHECK(output.scalar_type() == hidden_states.scalar_type(),
"Output and hidden states must have the same scalar type");
TORCH_CHECK(output.scalar_type() == weights.scalar_type(),
"Output and weights must have the same scalar type");
// Dispatch the datatypes
DISPATCH_BLAS_MATMUL(kFloat, BlasType::FP32);
DISPATCH_BLAS_MATMUL(kHalf, BlasType::FP16);
#ifdef BF16_AVAILABLE
DISPATCH_BLAS_MATMUL(kBFloat16, BlasType::BF16);
#endif
}
#define DISPATCH_4D_BLAS(T_TYPE, C_TYPE) \
if (C.options().dtype() == torch::T_TYPE) { \
blas_strided_batched_gemm(C.data_ptr(), \
(const void*)A.data_ptr(), \
(const void*)B.data_ptr(), \
m, \
n, \
k, \
lda, \
ldb, \
ldc, \
trans_a, \
trans_b, \
&alpha, \
&beta, \
stride_a, \
stride_b, \
stride_c, \
batch, \
C_TYPE); \
}
void blas_4d_matmul(at::Tensor& C, at::Tensor& B, at::Tensor& A)
{
/*
C shape: (batch_size, N, M)
A shape: (batch_size, N, K)
B shape: (batch_size, K, M)
*/
const int n = C.size(-2);
const int k = C.size(-1);
const int m = B.size(-1);
// A strides
const bool trans_a = A.stride(-1) == 1;
const int lda = (trans_a) ? A.stride(-2) : A.stride(-1);
const int stride_a = A.stride(-3);
// B strides
const bool trans_b = B.stride(-1) != 1;
const int ldb = (trans_b) ? B.stride(-1) : B.stride(-2);
const int stride_b = B.stride(-3);
// C strides
const int ldc = C.stride(-2);
const int stride_c = C.stride(-3);
const float alpha = 1.0f;
const float beta = 0.0f;
const int batch = C.numel() / (n * m);
// Dispatch the datatypes
DISPATCH_4D_BLAS(kFloat, BlasType::FP32);
DISPATCH_4D_BLAS(kHalf, BlasType::FP16);
#ifdef BF16_AVAILABLE
DISPATCH_4D_BLAS(kBFloat16, BlasType::BF16);
#endif
}
void create_handle() { BlasContext::getInstance().get_handle(); }

View file

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ....inference_utils import DtypeEnum
from deepspeed.ops.op_builder import InferenceCoreBuilder
from ... import DSKernelBase
class BlasLibLinear(DSKernelBase):
"""
Wrapper around the BLAS matmul kernel for FP16/BF16/FP32 for CUDA/ROCm.
Performs z = x @ y
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]
def __init__(self, fp_dtype: DtypeEnum):
"""
Parameters:
fp_dtype (torch.dtype): Data type for the input/output. Supported values
are torch.float16, torch.bfloat16, and torch.float32.
"""
fp_dtype = DtypeEnum(fp_dtype)
if fp_dtype not in BlasLibLinear.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, BlasLibLinear.supported_dtypes))
self.inf_module = InferenceCoreBuilder().load()
self.inf_module.create_handle()
self.kernel = self.inf_module.blas_linear
def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""
Matmul kernel as implemented by platform BLAS library. The input must be 2D or larger. If
n-dimensional, the leading dimensions are folded into each other:
2D: m = x.size(0)
3D: m = x.size(0) * x.size(1)
4D: m = x.size(0) * x.size(1) * x.size(2) (etc...)
All inputs should be contiguous.
Parameters:
output (torch.Tensor): Output tensor. Shape is [*, out_features]
hidden_states (torch.Tensor): Input tensor. Shape is [*, in_features]
weights (torch.Tensor): Input tensor. Shape is [out_features, in_features]
Returns:
z (torch.Tensor): Output tensor. Shape is [m, n]
"""
self.kernel(output, hidden_states, weights)
return output
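Illustrative usage with a preallocated output buffer (shapes are placeholders):

import torch

linear = BlasLibLinear(DtypeEnum.fp16)
x = torch.randn(32, 1024, dtype=torch.float16, device="cuda")    # [tokens, in_features]
w = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")  # [out_features, in_features]
y = torch.empty(32, 4096, dtype=torch.float16, device="cuda")    # [tokens, out_features]
linear(y, x, w)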

View file

@ -0,0 +1,275 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#include <mma.h>
#endif
#include <stdio.h>
#include <iostream>
#include <stdexcept>
class BlasContext {
/*
Slim wrapper for managing the lifetime of the platform's BLAS handle. This should
be hipified for ROCm.
*/
public:
BlasContext()
{
if (cublasCreate(&_handle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Failed to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
#ifndef __HIP_PLATFORM_HCC__
cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH);
#endif
}
virtual ~BlasContext() { cublasDestroy(_handle); }
static BlasContext& getInstance()
{
// Should always access the singleton through this function.
static BlasContext _instance;
return _instance;
}
cublasHandle_t get_handle() const { return _handle; }
private:
cublasHandle_t _handle;
};
enum class BlasType { FP32, FP16, BF16 };
#ifdef __HIP_PLATFORM_HCC__
rocblas_operation get_trans_op(bool do_trans)
{
return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
}
rocblas_datatype get_datatype(BlasType type)
{
switch (type) {
case BlasType::FP32: return rocblas_datatype_f32_r;
case BlasType::FP16: return rocblas_datatype_f16_r;
case BlasType::BF16: return rocblas_datatype_bf16_r;
default: throw std::runtime_error("Unsupported BlasType");
}
}
#else
cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T : CUBLAS_OP_N; }
cublasDataType_t get_datatype(BlasType type)
{
switch (type) {
case BlasType::FP32: return CUDA_R_32F;
case BlasType::FP16: return CUDA_R_16F;
case BlasType::BF16: return CUDA_R_16BF;
default: throw std::runtime_error("Unsupported BlasType");
}
}
#endif
int blas_gemm_ex(void* C,
const void* A,
const void* B,
int m,
int n,
int k,
int lda,
int ldb,
int ldc,
bool transa,
bool transb,
const float* alpha,
const float* beta,
BlasType type)
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_operation transa_op = get_trans_op(transa);
rocblas_operation transb_op = get_trans_op(transb);
rocblas_datatype abc_type = get_datatype(type);
rocblas_status status = rocblas_gemm_ex(BlasContext::getInstance().get_handle(),
transa_op,
transb_op,
m,
n,
k,
(const void*)alpha,
A,
abc_type,
lda,
B,
abc_type,
ldb,
(const void*)beta,
C,
abc_type,
ldc,
C,
abc_type,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
cublasOperation_t transa_op = get_trans_op(transa);
cublasOperation_t transb_op = get_trans_op(transb);
cublasDataType_t abc_type = get_datatype(type);
cublasStatus_t status = cublasGemmEx(BlasContext::getInstance().get_handle(),
transa_op,
transb_op,
m,
n,
k,
(const void*)alpha,
A,
abc_type,
lda,
B,
abc_type,
ldb,
(const void*)beta,
C,
abc_type,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != CUBLAS_STATUS_SUCCESS) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int blas_strided_batched_gemm(void* C,
const void* A,
const void* B,
int m,
int n,
int k,
int lda,
int ldb,
int ldc,
bool transa,
bool transb,
const float* alpha,
const float* beta,
int stride_A,
int stride_B,
int stride_C,
int batch,
BlasType type)
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_operation transa_op = get_trans_op(transa);
rocblas_operation transb_op = get_trans_op(transb);
rocblas_datatype abc_type = get_datatype(type);
rocblas_status status =
rocblas_gemm_strided_batched_ex(BlasContext::getInstance().get_handle(),
transa_op,
transb_op,
m,
n,
k,
(const void*)alpha,
A,
abc_type,
lda,
stride_A,
B,
abc_type,
ldb,
stride_B,
(const void*)beta,
C,
abc_type,
ldc,
stride_C,
C,
abc_type,
ldc,
stride_C,
batch,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
cublasOperation_t transa_op = get_trans_op(transa);
cublasOperation_t transb_op = get_trans_op(transb);
cublasDataType_t abc_type = get_datatype(type);
cublasStatus_t status = cublasGemmStridedBatchedEx(BlasContext::getInstance().get_handle(),
transa_op,
transb_op,
m,
n,
k,
(const void*)alpha,
A,
abc_type,
lda,
stride_A,
B,
abc_type,
ldb,
stride_B,
(const void*)beta,
C,
abc_type,
ldc,
stride_C,
batch,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != CUBLAS_STATUS_SUCCESS) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "bias_activation.h"
#include "blas.h"
#include "gated_activation_kernels.h"
#include "layer_norm.h"
#include "rms_norm.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// bias_activation.h
m.def("bias_activation", &bias_activation, "DeepSpeed bias activation in CUDA");
// layer_norm.h
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm in CUDA");
m.def("pre_layer_norm", &ds_pre_layer_norm, "DeepSpeed pre layer norm in CUDA");
m.def("post_layer_norm", &ds_post_layer_norm, "DeepSpeed pre layer norm in CUDA");
// blas.h
m.def("blas_linear", &blas_linear, "Linear implemented by vendor BLAS");
m.def("blas_4d_matmul", &blas_4d_matmul, "4D matmul implemented by vendor BLAS");
m.def("create_handle", &create_handle, "Create a handle for vendor BLAS");
// gated_activation_kernels.h
m.def("gated_activation", &ds_gated_activation, "DeepSpeed gated activation in CUDA");
// rms_norm.h
m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");
}

View file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .cuda_ln import *
from .cuda_post_ln import *
from .cuda_pre_ln import *

View file

@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from ....inference_utils import elem_size
from deepspeed.ops.op_builder import InferenceCoreBuilder
class CUDAFPLNBase(DSKernelBase):
"""
Base class for CUDA LN kernels. They all use the same validation logic,
so we can share it here.
"""
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
"""
Parameters:
channels (int): Number of channels in the input tensor. The row size in bytes
(elem_size(fp_dtype) * channels) must be divisible by 16.
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
are torch.float16, torch.bfloat16, and torch.float32.
"""
if fp_dtype not in CUDAFPLNBase.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, CUDAFPLNBase.supported_dtypes))
if elem_size(fp_dtype) * channels % 16 != 0:
raise ValueError("channels must be divisible by 16 bytes")
self.inf_module = InferenceCoreBuilder().load()
self.epsilon = epsilon

View file

@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .cuda_fp_ln_base import CUDAFPLNBase
class CUDAFPLN(CUDAFPLNBase):
"""
Floating point layer norm kernel for CUDA/ROCm.
Performs: z = ln(x)
"""
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor,
beta: torch.Tensor) -> torch.Tensor:
"""
output_z may alias input_x directly. All Tensors should have the same shape.
Parameters:
output_z (torch.Tensor): Output tensor.
input_x (torch.Tensor): Input tensor.
gamma (torch.Tensor): Gamma tensor.
beta (torch.Tensor): Beta tensor.
"""
self.inf_module.layer_norm(output_z, input_x, gamma, beta, self.epsilon)
return output_z
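A minimal usage sketch for the wrapper above (illustrative only: the import path is a placeholder for wherever this module lands in the package, shapes are arbitrary, and it assumes the inference core ops build on the local CUDA toolchain):
import torch
from cuda_ln import CUDAFPLN  # placeholder import; use the package path of this module

channels = 768                # elem_size(fp_dtype) * channels must be divisible by 16 bytes
ln = CUDAFPLN(channels, torch.float16, epsilon=1e-5)

x = torch.randn(32, channels, dtype=torch.float16, device="cuda")   # ragged [T, C] layout
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
beta = torch.zeros(channels, dtype=torch.float16, device="cuda")

out = torch.empty_like(x)
ln(out, x, gamma, beta)       # out may alias x for an in-place layer norm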


@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .cuda_fp_ln_base import CUDAFPLNBase
class CUDAFPPostLN(CUDAFPLNBase):
"""
Floating point post-LayerNorm kernel for CUDA/ROCm.
Performs: z = ln(x + y)
"""
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, input_y: torch.Tensor, gamma: torch.Tensor,
beta: torch.Tensor) -> torch.Tensor:
"""
Either input_x or input_y can alias output_z.
Parameters:
output_z (torch.Tensor): Output tensor.
input_x (torch.Tensor): Input tensor.
input_y (torch.Tensor): Input tensor.
gamma (torch.Tensor): Gamma tensor.
beta (torch.Tensor): Beta tensor.
Returns:
output (torch.Tensor): Output tensor.
"""
self.inf_module.post_layer_norm(output_z, input_x, input_y, gamma, beta, self.epsilon)
return output_z


@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Tuple
import torch
from .cuda_fp_ln_base import CUDAFPLNBase
class CUDAFPPreLN(CUDAFPLNBase):
"""
Floating point pre-LayerNorm kernel for CUDA/ROCm.
Performs: z_res = x_res + y_hid
z_hid = ln(z_res)
"""
def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
gamma: torch.Tensor, beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
z_res can alias x_res. All non-parameter input/output tensors
must have the same shape. z_hid can alias y_hid.
Parameters:
z_res (torch.Tensor): Output residual.
z_hid (torch.Tensor): Output hidden states.
x_res (torch.Tensor): Input residual.
y_hid (torch.Tensor): Input hidden states.
gamma (torch.Tensor): Gamma tensor.
beta (torch.Tensor): Beta tensor.
Returns:
output (Tuple[torch.Tensor, torch.Tensor]): The (z_res, z_hid) pair.
"""
self.inf_module.pre_layer_norm(z_res, z_hid, x_res, y_hid, gamma, beta, self.epsilon)
return z_res, z_hid
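A corresponding sketch for the pre-LN wrapper, showing the aliasing pattern the docstring allows (placeholder import path, illustrative shapes):
import torch
from cuda_pre_ln import CUDAFPPreLN  # placeholder import path

channels = 768
pre_ln = CUDAFPPreLN(channels, torch.float16)

x_res = torch.randn(32, channels, dtype=torch.float16, device="cuda")  # running residual stream
y_hid = torch.randn(32, channels, dtype=torch.float16, device="cuda")  # previous block's output
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
beta = torch.zeros(channels, dtype=torch.float16, device="cuda")

# z_res receives x_res + y_hid and z_hid receives ln(x_res + y_hid); aliasing the
# outputs onto x_res / y_hid updates both buffers in place.
z_res, z_hid = pre_ln(x_res, y_hid, x_res, y_hid, gamma, beta)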


@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "layer_norm.h"
#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \
if (input.options().dtype() == torch::T_TYPE) { \
launch_fused_ln((C_TYPE*)output.data_ptr(), \
(const C_TYPE*)input.data_ptr(), \
(const C_TYPE*)gamma.data_ptr(), \
(const C_TYPE*)beta.data_ptr(), \
epsilon, \
rows, \
elems_per_row, \
at::cuda::getCurrentCUDAStream()); \
}
void ds_layer_norm(at::Tensor& output,
at::Tensor& input,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
bool ragged_input = input.dim() == 2;
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
DISPATCH_LAYER_NORM(kFloat, float);
DISPATCH_LAYER_NORM(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16);
#endif
}
#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
if (input.options().dtype() == torch::T_TYPE) { \
launch_fused_post_ln((C_TYPE*)output.data_ptr(), \
(const C_TYPE*)input.data_ptr(), \
(const C_TYPE*)residual.data_ptr(), \
(const C_TYPE*)gamma.data_ptr(), \
(const C_TYPE*)beta.data_ptr(), \
epsilon, \
rows, \
elems_per_row, \
at::cuda::getCurrentCUDAStream()); \
}
void ds_post_layer_norm(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
bool ragged_input = input.dim() == 2;
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float);
DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
#endif
}
#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
if (input.options().dtype() == torch::T_TYPE) { \
launch_fused_pre_ln((C_TYPE*)norm_output.data_ptr(), \
(C_TYPE*)res_output.data_ptr(), \
(const C_TYPE*)input.data_ptr(), \
(const C_TYPE*)residual.data_ptr(), \
(const C_TYPE*)gamma.data_ptr(), \
(const C_TYPE*)beta.data_ptr(), \
epsilon, \
rows, \
elems_per_row, \
at::cuda::getCurrentCUDAStream()); \
}
void ds_pre_layer_norm(at::Tensor& res_output,
at::Tensor& norm_output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
bool ragged_input = input.dim() == 2;
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float);
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
#endif
}
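All three bindings above accept either a ragged [tokens, channels] tensor or a batched [batch, seq, channels] tensor. A small Python sketch of the shared rows/elems_per_row derivation (for illustration only):
import torch

def ln_launch_shape(t: torch.Tensor):
    # dim == 2 -> ragged [T, C]; dim == 3 -> batched [B, S, C] flattened to B * S rows
    if t.dim() == 2:
        return t.size(0), t.size(1)
    return t.size(0) * t.size(1), t.size(2)

rows, elems_per_row = ln_launch_shape(torch.empty(4, 128, 768))   # -> (512, 768)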


@ -0,0 +1,490 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
using rop = reduce::ROpType;
namespace ln {
constexpr int granularity = 16;
} // namespace ln
/*
Regular layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
Args:
output: buffer for output data
vals: buffer for input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads>
__global__ void fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = blockDim.x * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
T local_buffer[unRoll * T_per_load];
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
sum = reduce::element<rop::Add>(sum, vals_up_cast);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
T* block_output = output + block_offset;
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float val = conversion::to<float>(iteration_buffer[j]);
val = (val - mean) * denom;
val =
val * conversion::to<float>(gamma_local[j]) + conversion::to<float>(beta_local[j]);
iteration_buffer[j] = conversion::to<T>(val);
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \
fused_ln<T, unRollFactor, threadsPerGroup, maxThreads> \
<<<grid, block, 0, stream>>>(output, vals, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
#define INSTANTIATE_FUSED_LN(T) \
template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
INSTANTIATE_FUSED_LN(__half);
#ifdef BF16_AVAILABLE
INSTANTIATE_FUSED_LN(__nv_bfloat16);
#endif
INSTANTIATE_FUSED_LN(float);
/*
Fused residual + bias + layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
need to be fused into compute-bound producer operations.
Args:
output: buffer for output data
res_output: output of residual addition
vals: buffer for input data
residual: residual data
bias: bias of input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
Template arg:
StoreResidual: controls whether the residual calculation is stored
or not. When set to false, the input `res_output` is unused.
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads, bool preLnResidual>
__global__ void fused_residual_ln(T* output,
T* res_output,
const T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
const T* residual_base = residual + base_offset;
T local_buffer[unRoll * T_per_load];
// Unlike a vanilla layernorm, since we're fusing the two adds as well
// an inner unRoll seems to be less valuable. If anything, a double unRoll
// makes the most sense if we find we are having performance issues.
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(residual_buffer,
residual_base + i * stride,
thread_offset + i * stride < elems_per_row);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
float res_up_cast = conversion::to<float>(residual_buffer[j]);
vals_up_cast += res_up_cast;
sum = reduce::element<rop::Add>(sum, vals_up_cast);
iteration_buffer[j] = conversion::to<T>(vals_up_cast);
}
if (preLnResidual && (thread_offset + i * stride < elems_per_row)) {
mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
iteration_buffer);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
T* block_output = output + block_offset;
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float val = conversion::to<float>(iteration_buffer[j]);
val = (val - mean) * denom;
val =
val * conversion::to<float>(gamma_local[j]) + conversion::to<float>(beta_local[j]);
iteration_buffer[j] = conversion::to<T>(val);
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \
fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, false> \
<<<grid, block, 0, stream>>>( \
output, nullptr, vals, residual, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_post_ln(T* output,
const T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \
fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, true> \
<<<grid, block, 0, stream>>>( \
norm_output, res_output, vals, residual, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_pre_ln(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads);
}
}
#define INSTANTIATE_RES_LN(T) \
template void launch_fused_post_ln<T>( \
T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
#define INSTANTIATE_PRE_LN_RES(T) \
template void launch_fused_pre_ln<T>( \
T*, T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
INSTANTIATE_RES_LN(__half);
INSTANTIATE_RES_LN(float);
#ifdef BF16_AVAILABLE
INSTANTIATE_RES_LN(__nv_bfloat16);
#endif
INSTANTIATE_PRE_LN_RES(__half);
INSTANTIATE_PRE_LN_RES(float);
#ifdef BF16_AVAILABLE
INSTANTIATE_PRE_LN_RES(__nv_bfloat16);
#endif
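To make the launch heuristic above concrete, here is a hedged Python sketch that mirrors the threads-per-group and unroll selection shared by launch_fused_ln, launch_fused_post_ln, and launch_fused_pre_ln (constants copied from the code; this illustrates the scheduling math only and is not part of the kernel):
def next_pow2(x: int) -> int:
    p = 1
    while p < x:
        p *= 2
    return p

def ln_schedule(elems_per_row: int, elem_bytes: int):
    granularity, max_threads = 16, 256
    t_per_load = granularity // elem_bytes                 # 8 for fp16/bf16, 4 for fp32
    internal_unroll = 4 if elem_bytes == 4 else 2
    subblock = elems_per_row <= 128
    h_per_step = t_per_load if subblock else t_per_load * internal_unroll
    threads = min(next_pow2((elems_per_row + h_per_step - 1) // h_per_step), max_threads)
    elems_per_step = threads * h_per_step
    external_unroll = (elems_per_row + elems_per_step - 1) // elems_per_step
    unroll = 1 if subblock else external_unroll * internal_unroll
    return threads, unroll

print(ln_schedule(5120, 2))   # fp16 row of 5120 elems -> (256, 4), i.e. fused_ln<T, 4, 256, 256>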


@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "ds_kernel_utils.h"
/*
Kernel launch methods for layer norm variants.
*/
template <typename T>
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
template <typename T>
void launch_fused_post_ln(T* output,
const T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
template <typename T>
void launch_fused_pre_ln(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
void ds_layer_norm(at::Tensor& output,
at::Tensor& input,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon);
void ds_post_layer_norm(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon);
void ds_pre_layer_norm(at::Tensor& res_output,
at::Tensor& norm_output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon);


@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .rms_norm import CUDARMSNorm
from .rms_pre_norm import CUDARMSPreNorm


@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "rms_norm.h"
#ifdef BF16_AVAILABLE
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using scalar_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kBFloat16) { \
using scalar_t = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for RMSNorm"); \
} \
}()
#else
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using scalar_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for RMSNorm"); \
} \
}()
#endif
void rms_norm(torch::Tensor& norm_output,
torch::Tensor& norm_input,
torch::Tensor& gamma,
float epsilon)
{
TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
"norm_output and norm_input should have the same data type");
TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
"norm_output and gamma should have the same data type");
const int32_t rows = norm_input.size(0);
const int32_t cols = norm_input.size(1);
TORCH_CHECK(norm_output.size(0) == rows,
"norm_output and norm_input should have the same first dimension");
TORCH_CHECK(norm_output.size(1) == cols,
"norm_output and norm_input should have the same second dimension");
DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
scalar_t* norm_input_ptr = reinterpret_cast<scalar_t*>(norm_input.data_ptr());
scalar_t* gamma_ptr = reinterpret_cast<scalar_t*>(gamma.data_ptr());
scalar_t* null_t = nullptr;
launch_rms_norm(norm_output_ptr,
null_t,
norm_input_ptr,
null_t,
gamma_ptr,
epsilon,
rows,
cols,
at::cuda::getCurrentCUDAStream());
});
}
void rms_pre_norm(torch::Tensor& norm_output,
torch::Tensor& residual_output,
torch::Tensor& norm_input,
torch::Tensor& residual_input,
torch::Tensor& gamma,
float epsilon)
{
TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
"norm_output and norm_input should have the same data type");
TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
"norm_output and gamma should have the same data type");
const int32_t rows = norm_input.size(0);
const int32_t cols = norm_input.size(1);
TORCH_CHECK(norm_output.size(0) == rows,
"norm_output and norm_input should have the same first dimension");
TORCH_CHECK(norm_output.size(1) == cols,
"norm_output and norm_input should have the same second dimension");
TORCH_CHECK(residual_output.size(0) == rows,
"residual_output and norm_input should have the same first dimension");
TORCH_CHECK(residual_output.size(1) == cols,
"residual_output and norm_input should have the same second dimension");
TORCH_CHECK(residual_input.size(0) == rows,
"residual_input and norm_input should have the same first dimension");
TORCH_CHECK(residual_input.size(1) == cols,
"residual_input and norm_input should have the same second dimension");
DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
scalar_t* residual_output_ptr = reinterpret_cast<scalar_t*>(residual_output.data_ptr());
const scalar_t* norm_input_ptr = reinterpret_cast<const scalar_t*>(norm_input.data_ptr());
const scalar_t* residual_input_ptr =
reinterpret_cast<const scalar_t*>(residual_input.data_ptr());
const scalar_t* gamma_ptr = reinterpret_cast<const scalar_t*>(gamma.data_ptr());
launch_rms_norm(norm_output_ptr,
residual_output_ptr,
norm_input_ptr,
residual_input_ptr,
gamma_ptr,
epsilon,
rows,
cols,
at::cuda::getCurrentCUDAStream());
});
}


@ -0,0 +1,262 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
using rop = reduce::ROpType;
namespace rms {
constexpr int granularity = 16;
} // namespace rms
template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row)
{
constexpr int T_per_load = rms::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = blockDim.x * T_per_load;
float var_sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
T local_buffer[UNROLL * T_per_load];
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
T* iteration_buffer = local_buffer + (i * T_per_load);
mem_access::load_global<rms::granularity>(iteration_buffer,
input_base + (i * stride),
thread_offset + (i * stride) < elems_per_row);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
float up_cast = conversion::to<float>(iteration_buffer[j]);
float sq_val = up_cast * up_cast;
var_sum = reduce::element<rop::Add, float>(var_sum, sq_val);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, var_sum);
const float var = var_sum / elems_per_row;
const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
T* iteration_buffer = local_buffer + (i * T_per_load);
const int iter_idx = i * stride + thread_offset;
const bool do_loads = (iter_idx < elems_per_row);
T gamma_local[T_per_load];
mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] *= denom;
iteration_buffer[j] *= gamma_local[j];
}
if (do_loads) {
mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
__global__ void pre_rms_norm(T* output,
T* res_out,
const T* vals,
const T* residual,
const T* gamma,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = rms::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = blockDim.x * T_per_load;
float var_sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
const T* residual_base = residual + base_offset;
T* res_output = res_out + base_offset;
T local_buffer[UNROLL * T_per_load];
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
T* iteration_buffer = local_buffer + (i * T_per_load);
T residual_buffer[T_per_load];
const int iter_offset = i * stride + thread_offset;
const bool do_loads = (iter_offset < elems_per_row);
mem_access::load_global<rms::granularity>(
iteration_buffer, input_base + (i * stride), do_loads);
mem_access::load_global<rms::granularity>(
residual_buffer, residual_base + (i * stride), do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] += residual_buffer[j];
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
var_sum = reduce::element<rop::Add, float>(var_sum, vals_up_cast * vals_up_cast);
}
if (do_loads) {
mem_access::store_global<rms::granularity>(res_output + i * stride, iteration_buffer);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, var_sum);
const float var = var_sum / elems_per_row;
const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
T* iteration_buffer = local_buffer + (i * T_per_load);
const int iter_idx = i * stride + thread_offset;
const bool do_loads = (iter_idx < elems_per_row);
T gamma_local[T_per_load];
mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] *= denom;
iteration_buffer[j] *= gamma_local[j];
}
if (do_loads) {
mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
rms_norm<T, UNROLL, threadsPerGroup, maxThreads> \
<<<grid, block, 0, stream>>>(norm_output, vals, gamma, epsilon, elems_per_row);
#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
pre_rms_norm<T, UNROLL, threadsPerGroup, maxThreads><<<grid, block, 0, stream>>>( \
norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row);
#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
if (pre_norm) { \
LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
} else { \
LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
}
template <typename T>
void launch_rms_norm(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* gamma,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = rms::granularity / sizeof(T);
constexpr int maxThreads = 256;
constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threads_per_group, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threads_per_group * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
bool pre_norm = (residual == nullptr) ? false : true;
if (is_subblock_schedule) {
// <=128
if (threads_per_group == 1) {
LAUNCH_ALL_RMS_NORM(1, 1, maxThreads);
} else if (threads_per_group == 2) {
LAUNCH_ALL_RMS_NORM(1, 2, maxThreads);
} else if (threads_per_group == 4) {
LAUNCH_ALL_RMS_NORM(1, 4, maxThreads);
} else if (threads_per_group == 8) {
LAUNCH_ALL_RMS_NORM(1, 8, maxThreads);
} else if (threads_per_group == 16) {
LAUNCH_ALL_RMS_NORM(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads);
}
}
#define INSTANTIATE_LAUNCH_RMS_NORM(T) \
template void launch_rms_norm<T>(T * norm_output, \
T * res_output, \
const T* vals, \
const T* residual, \
const T* gamma, \
float epsilon, \
int rows, \
int elems_per_row, \
cudaStream_t stream);
INSTANTIATE_LAUNCH_RMS_NORM(float)
INSTANTIATE_LAUNCH_RMS_NORM(__half)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16)
#endif


@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "ds_kernel_utils.h"
template <typename T>
void launch_rms_norm(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* gamma,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
void rms_norm(torch::Tensor& norm_output,
torch::Tensor& norm_input,
torch::Tensor& gamma,
float epsilon);
void rms_pre_norm(torch::Tensor& norm_output,
torch::Tensor& residual_output,
torch::Tensor& norm_input,
torch::Tensor& residual_input,
torch::Tensor& gamma,
float epsilon);


@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .rms_norm_base import CUDARMSNormBase
class CUDARMSNorm(CUDARMSNormBase):
"""
Floating point RMS norm kernel for CUDA/ROCm.
Performs: z = rms_norm(x)
"""
def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
"""
output_z may alias input_x directly. All Tensors should have the same shape.
Parameters:
output_z (torch.Tensor): Output tensor.
input_x (torch.Tensor): Input tensor.
gamma (torch.Tensor): Gamma tensor.
"""
self.inf_module.rms_norm(output_z, input_x, gamma, self.epsilon)
return output_z


@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from ....inference_utils import elem_size
from deepspeed.ops.op_builder import InferenceCoreBuilder
class CUDARMSNormBase(DSKernelBase):
"""
Base class for CUDA RMS norm kernels. They all use the same validation logic,
so we can share it here.
"""
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
"""
Parameters:
channels (int): Number of channels in the input tensor. elem_size(fp_dtype) * channels
must be divisible by 16 bytes.
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
are torch.float16, torch.bfloat16, and torch.float32.
"""
if fp_dtype not in CUDARMSNormBase.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, CUDARMSNormBase.supported_dtypes))
if elem_size(fp_dtype) * channels % 16 != 0:
raise ValueError("channels must be divisible by 16 bytes")
self.inf_module = InferenceCoreBuilder().load()
self.epsilon = epsilon


@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Tuple
import torch
from .rms_norm_base import CUDARMSNormBase
class CUDARMSPreNorm(CUDARMSNormBase):
"""
Floating point RMS pre-norm kernel for CUDA/ROCm.
Performs: z_res = x_res + y_hid
z_hid = rms_norm(z_res)
"""
def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
z_res can alias x_res. All non-parameter input/output tensors
must have the same shape. z_hid can alias y_hid.
Parameters:
z_res (torch.Tensor): Output residual.
z_hid (torch.Tensor): Output hidden states.
x_res (torch.Tensor): Input residual.
y_hid (torch.Tensor): Input hidden states.
gamma (torch.Tensor): Gamma tensor.
Returns:
output (Tuple[torch.Tensor, torch.Tensor]): The (z_res, z_hid) pair.
"""
self.inf_module.rms_pre_norm(z_hid, z_res, y_hid, x_res, gamma, self.epsilon)
return z_res, z_hid
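A usage sketch for the RMS pre-norm wrapper (placeholder import path, illustrative shapes); note how the Python argument order maps onto the rms_pre_norm binding above:
import torch
from rms_pre_norm import CUDARMSPreNorm  # placeholder import path

channels = 4096
rms_pre = CUDARMSPreNorm(channels, torch.float16)

x_res = torch.randn(16, channels, dtype=torch.float16, device="cuda")
y_hid = torch.randn(16, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")

# Returns (z_res, z_hid) = (x_res + y_hid, rms_norm(x_res + y_hid) * gamma); the
# outputs may alias x_res / y_hid for an in-place update.
z_res, z_hid = rms_pre(x_res, y_hid, x_res, y_hid, gamma)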


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .gated_activation import *


@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Optional
import torch
from ... import DSKernelBase
from ....inference_utils import ActivationType, elem_size
from deepspeed.ops.op_builder import InferenceCoreBuilder
class CUDAGatedActivation(DSKernelBase):
"""
CUDA implementation of gated activation kernel. This kernel assumes that the input
tensor has gate and activation values in adjacent channels. The output tensor has
half as many channels as the input tensor.
"""
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
supported_act_fns = [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
def __init__(self, channels: int, fp_dtype: torch.dtype, act_fn: ActivationType) -> None:
"""
Compile and validate the gated activation kernel.
Args:
channels (int): Number of columns in the output tensor. elem_size(fp_dtype) * channels
must be divisible by 8 bytes.
fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
are torch.float16, torch.bfloat16, and torch.float32.
act_fn (ActivationType): Activation function to use. Supported values are GEGLU, ReGLU, and SiGLU.
"""
if fp_dtype not in CUDAGatedActivation.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, CUDAGatedActivation.supported_dtypes))
act_fn = ActivationType(act_fn)
if act_fn not in CUDAGatedActivation.supported_act_fns:
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
act_fn, CUDAGatedActivation.supported_act_fns))
if elem_size(fp_dtype) * channels % 8 != 0:
raise ValueError("Channels must be divisible by 16 bytes")
if elem_size(fp_dtype) * channels > 98304:
raise ValueError(
"Kernel only compiled to support 98304 bytes per row, please file an issue if your model requires more."
)
self.inf_module = InferenceCoreBuilder().load()
self.act_fn = act_fn
self.kernel = self.inf_module.gated_activation
def __call__(self, output: torch.Tensor, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
"""
Performs gated activation on the input tensor, writing the result to the output tensor.
Args:
output (torch.Tensor): Output tensor. Can be of shape [T, C // 2] or [B, S, C // 2]
input (torch.Tensor): Input tensor. Can be of shape [T, C] or [B, S, C]
bias (torch.Tensor, optional): Bias tensor of shape [C], added to the input before gating.
"""
self.kernel(output, input, bias, self.act_fn.value)
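A usage sketch (illustrative shapes; both import paths below are assumptions about where this module and ActivationType live in the package):
import torch
from gated_activation import CUDAGatedActivation                    # placeholder import path
from deepspeed.inference.v2.inference_utils import ActivationType   # assumed module location

out_channels = 4096    # columns of the output; the input carries 2x this, gate/act interleaved
geglu = CUDAGatedActivation(out_channels, torch.float16, ActivationType.GEGLU)

x = torch.randn(64, 2 * out_channels, dtype=torch.float16, device="cuda")   # [T, C]
out = torch.empty(64, out_channels, dtype=torch.float16, device="cuda")     # [T, C // 2]
geglu(out, x)          # an optional bias of shape [C] may be passed as the third argument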


@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "gated_activation_kernels.h"
#ifdef BF16_AVAILABLE
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using scalar_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kBFloat16) { \
using scalar_t = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for GatedActivation"); \
} \
}()
#else
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using scalar_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using scalar_t = __half; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dtype for GatedActivation"); \
} \
}()
#endif
void ds_gated_activation(at::Tensor& output,
at::Tensor& input,
c10::optional<torch::Tensor>& bias,
int activation_type_raw)
{
bool ragged_input = input.dim() == 2;
const ActivationType activation_type = static_cast<ActivationType>(activation_type_raw);
const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
const int cols = ragged_input ? input.size(1) : input.size(2);
DISPATCH_FOR_FLOAT(input.scalar_type(), [&] {
scalar_t* bias_ptr = nullptr;
if (bias.has_value()) {
TORCH_CHECK(bias.value().scalar_type() == input.scalar_type(),
"Bias type must match input type");
TORCH_CHECK(bias.value().numel() == cols,
"Bias must have the same number of elements as the input channels");
bias_ptr = reinterpret_cast<scalar_t*>(bias.value().data_ptr());
}
scalar_t* output_ptr = reinterpret_cast<scalar_t*>(output.data_ptr());
const scalar_t* input_ptr = reinterpret_cast<const scalar_t*>(input.data_ptr());
launch_gated_activation(output_ptr,
input_ptr,
bias_ptr,
rows,
cols,
activation_type,
c10::cuda::getCurrentCUDAStream());
});
}


@ -0,0 +1,169 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <stdexcept>
#include "activation_type.h"
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
namespace gated_act {
constexpr int access_size = 16;
constexpr int threads = 1024;
template <ActivationType ActType>
float gated_act_fn(float x, float y);
template <>
DS_D_INLINE float gated_act_fn<ActivationType::GEGLU>(float x, float y)
{
constexpr float sqrt_param = 0.79788456080286535587989211986876f;
constexpr float mul_param = 0.044715;
return y * x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
template <>
DS_D_INLINE float gated_act_fn<ActivationType::ReGLU>(float x, float y)
{
return y * (x > 0.0f ? x : 0.0f);
}
template <>
DS_D_INLINE float gated_act_fn<ActivationType::SiGLU>(float x, float y)
{
return y * (x / (1.0f + expf(-x)));
}
} // namespace gated_act
template <typename T, ActivationType ActType, int loopUnroll>
__global__ void gated_activation_kernel(T* output,
const T* input,
const T* bias,
int rows,
int cols)
{
constexpr int read_vector = gated_act::access_size / sizeof(T);
constexpr int write_vector = read_vector / 2;
const int row = blockIdx.x;
const int col = threadIdx.x * read_vector;
const T* input_row = input + row * cols;
T* output_row = output + row * cols / 2;
#pragma unroll
for (int i = 0; i < loopUnroll; i++) {
T read[read_vector];
T bias_read[read_vector];
T store[write_vector];
const int read_offset = col + gated_act::threads * read_vector * i;
const int write_offset = col / 2 + gated_act::threads * write_vector * i;
if (i != loopUnroll - 1 || read_offset < cols) {
mem_access::load_global<gated_act::access_size>(read, input_row + read_offset);
mem_access::load_global<gated_act::access_size>(
bias_read, bias + read_offset, bias != nullptr);
for (int j = 0; j < write_vector; j++) {
float g_val =
conversion::to<float>(read[j * 2]) + conversion::to<float>(bias_read[j * 2]);
float a_val = conversion::to<float>(read[j * 2 + 1]) +
conversion::to<float>(bias_read[j * 2 + 1]);
float act_val = gated_act::gated_act_fn<ActType>(g_val, a_val);
store[j] = conversion::to<T>(act_val);
}
mem_access::store_global<gated_act::access_size / 2>(output_row + write_offset, store);
}
}
}
#define DISPATCH_UNROLL(unroll_val) \
gated_activation_kernel<T, ActType, unroll_val> \
<<<grid, block, 0, stream>>>(output, input, bias, rows, cols);
template <typename T, ActivationType ActType>
void launch_gated_activation_impl(T* output,
const T* input,
const T* bias,
int rows,
int cols,
cudaStream_t stream)
{
constexpr int read_vector = gated_act::access_size / sizeof(T);
constexpr int cols_per_unroll = gated_act::threads * read_vector;
const int req_threads = (cols + read_vector - 1) / read_vector;
const int threads = std::min(req_threads, gated_act::threads);
const dim3 grid(rows);
const dim3 block(threads);
const int unroll = (cols + cols_per_unroll - 1) / cols_per_unroll;
if (unroll == 1) {
DISPATCH_UNROLL(1);
} else if (unroll == 2) {
DISPATCH_UNROLL(2);
} else if (unroll == 3) {
DISPATCH_UNROLL(3);
} else if (unroll == 4) {
DISPATCH_UNROLL(4);
} else if (unroll == 5) {
DISPATCH_UNROLL(5);
} else if (unroll == 6) {
DISPATCH_UNROLL(6);
} else {
throw std::runtime_error(
"Called with more columns than supported, please report this bug and this limit will "
"be increased.");
}
}
template <typename T>
void launch_gated_activation(T* output,
const T* input,
const T* bias,
int rows,
int cols,
ActivationType act_type,
cudaStream_t stream)
{
switch (act_type) {
case ActivationType::GEGLU:
launch_gated_activation_impl<T, ActivationType::GEGLU>(
output, input, bias, rows, cols, stream);
break;
case ActivationType::ReGLU:
launch_gated_activation_impl<T, ActivationType::ReGLU>(
output, input, bias, rows, cols, stream);
break;
case ActivationType::SiGLU:
launch_gated_activation_impl<T, ActivationType::SiGLU>(
output, input, bias, rows, cols, stream);
break;
default: throw std::runtime_error("Unsupported activation type");
}
}
#define INSTANTIATE_FOR_TYPE(T) \
template void launch_gated_activation<T>(T * output, \
const T* input, \
const T* bias, \
int rows, \
int cols, \
ActivationType act_type, \
cudaStream_t stream);
INSTANTIATE_FOR_TYPE(float)
INSTANTIATE_FOR_TYPE(__half)
#ifdef BF16_AVAILABLE
INSTANTIATE_FOR_TYPE(__nv_bfloat16)
#endif
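The GEGLU branch above uses the tanh approximation of GELU rather than the exact erf form. A small Python sketch of the three gating functions (illustration only, not part of the kernel) shows how close the approximation stays to exact GELU:
import math

def geglu(x, y):   # tanh-approximate GELU gate, matching the kernel constants
    c = 0.79788456080286535587989211986876   # sqrt(2 / pi)
    return y * x * 0.5 * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

def reglu(x, y):
    return y * max(x, 0.0)

def siglu(x, y):
    return y * x / (1.0 + math.exp(-x))

exact_gelu = lambda x: x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
print(geglu(1.0, 1.0), exact_gelu(1.0))   # ~0.8412 vs ~0.8413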


@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "activation_type.h"
#include "ds_kernel_utils.h"
template <typename T>
void launch_gated_activation(T* output,
const T* vals,
const T* bias,
int rows,
int cols,
ActivationType activation_type,
cudaStream_t stream);
void ds_gated_activation(at::Tensor& output,
at::Tensor& input,
c10::optional<torch::Tensor>& bias,
int activation_type_raw);


@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .mixed_gemm import *
from .moe_gemm import *


@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h>
#include "mixed_gemm.h"
#include "moe_gemm.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// mixed_gemm.h
m.def("mixed_gemm", &mixed_gemm, "Mixed-precision GEMM");
// moe_gemm.h
m.def("moe_gemm", &moe_gemm, "MultiGEMM for MoE (16-bit weights)");
m.def("mixed_moe_gemm", &mixed_moe_gemm, "MultiGEMM for MoE (4-bit/8-bit weights)");
}


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .mixed_gemm import *


@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h>
#include "mixed_gemm.h"
#include "mixed_gemm_api.h"
#include "weight_variant.h"
// Switch helpers inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#define ACT_DTYPE_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using ActivationDtype = __half; \
return __VA_ARGS__(); \
} else { \
using ActivationDtype = __nv_bfloat16; \
return __VA_ARGS__(); \
} \
}()
#define WEIGHT_VARIANT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr WeightVariant WVariant = WeightVariant::kFP8; \
return __VA_ARGS__(); \
} else { \
constexpr WeightVariant WVariant = WeightVariant::kFP4; \
return __VA_ARGS__(); \
} \
}()
void mixed_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
at::Tensor& scales,
c10::optional<at::Tensor>& bias,
int num_bits,
int activation_raw)
{
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
"Output and hidden states must have the same dtype");
TORCH_CHECK(num_bits == 4 || num_bits == 8, "Data width must be 4 or 8");
TORCH_CHECK(output.size(0) == hidden_states.size(0), "Token dimension mismatch");
int32_t m = output.size(0);
int32_t k = hidden_states.size(1);
int32_t n = weight.size(1);
TORCH_CHECK(weight.size(0) == k, "Weight dimension mismatch");
ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] {
fastertransformer::CutlassFpAIntBGemmRunner<ActivationDtype, WVariant> runner =
*MixedGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
ActivationType activation_type = (ActivationType)activation_raw;
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
runner.gemm((ActivationDtype*)hidden_states.data_ptr(),
(const char*)weight.data_ptr(),
(ActivationDtype*)scales.data_ptr(),
(ActivationDtype*)output.data_ptr(),
m,
n,
k,
nullptr,
0,
at::cuda::getCurrentCUDAStream());
return;
} else {
ActivationDtype* bias_ptr = nullptr;
if (bias.has_value()) { bias_ptr = (ActivationDtype*)bias.value().data_ptr(); }
runner.gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
(char*)weight.data_ptr(),
(ActivationDtype*)scales.data_ptr(),
bias_ptr,
(ActivationDtype*)output.data_ptr(),
m,
n,
k,
activation_type,
nullptr,
0,
at::cuda::getCurrentCUDAStream());
return;
}
});
});
}


@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <torch/extension.h>
void mixed_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
at::Tensor& scales,
c10::optional<at::Tensor>& bias,
int num_bits,
int activation_raw);


@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from ....inference_utils import ActivationType, DtypeEnum
from deepspeed.ops.op_builder import InferenceCutlassBuilder
from typing import Optional
class MixedGEMM(DSKernelBase):
"""
CUTLASS implementation of a mixed-precision GEMM (FP16/BF16 activations with 4-/8-bit quantized weights).
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None:
if not isinstance(fp_dtype, DtypeEnum):
fp_dtype = DtypeEnum(fp_dtype)
if fp_dtype not in MixedGEMM.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, MixedGEMM.supported_dtypes))
if act_fn not in MixedGEMM.supported_act_fns:
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
act_fn, MixedGEMM.supported_act_fns))
if num_bits != 4 and num_bits != 8:
raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits))
inf_module = InferenceCutlassBuilder().load()
self.num_bits = num_bits
self.kernel = inf_module.mixed_gemm
self.act_fn = act_fn
def __call__(self,
output: torch.Tensor,
hidden_states: torch.Tensor,
weights: torch.Tensor,
scales: torch.Tensor,
biases: Optional[torch.Tensor] = None) -> None:
"""
Performs a mixed-precision GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
Arguments:
output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
hidden_states (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
weights (torch.Tensor): The weights of shape [in_neurons, out_neurons]. These weights must be contiguous.
scales (torch.Tensor): The scales of shape [out_neurons]. These scales must be contiguous.
biases (torch.Tensor): The biases of shape [out_neurons]. These biases must be contiguous.
Returns:
output
"""
self.kernel(output, hidden_states, weights, scales, biases, self.num_bits, self.act_fn)
return output
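
For reference, a minimal usage sketch of the MixedGEMM wrapper above. The tensor sizes, the int8 placeholder for the packed weight, and the availability of MixedGEMM/DtypeEnum/ActivationType in the caller's namespace are assumptions; the actual weight/scale packing is whatever the underlying CUTLASS kernel expects.

import torch

n_tokens, in_neurons, out_neurons = 16, 4096, 11008  # hypothetical sizes

gemm = MixedGEMM(DtypeEnum.fp16, ActivationType.IDENTITY, num_bits=8)

hidden_states = torch.randn(n_tokens, in_neurons, dtype=torch.float16, device="cuda")
weight = torch.empty(in_neurons, out_neurons, dtype=torch.int8, device="cuda")  # packed 8-bit weights
scales = torch.ones(out_neurons, dtype=torch.float16, device="cuda")
output = torch.empty(n_tokens, out_neurons, dtype=torch.float16, device="cuda")

gemm(output, hidden_states, weight, scales)  # bias is optional and defaults to None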


@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "activation_type.h"
#include "weight_variant.h"
namespace fastertransformer {
template <typename T, WeightVariant V>
class CutlassFpAIntBGemmRunner {
public:
void gemm(const T* A,
const char* B,
const T* weight_scales,
T* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
void gemm_bias_act(const T* A,
const char* B,
const T* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
ActivationType activation_type,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
};
} // namespace fastertransformer
template <typename T, WeightVariant V>
class MixedGemmContext {
public:
MixedGemmContext() { _runner = new fastertransformer::CutlassFpAIntBGemmRunner<T, V>(); }
virtual ~MixedGemmContext() { delete _runner; }
static MixedGemmContext& Instance()
{
static MixedGemmContext _ctx;
return _ctx;
}
fastertransformer::CutlassFpAIntBGemmRunner<T, V>* GeMM_Runner() const { return _runner; }
fastertransformer::CutlassFpAIntBGemmRunner<T, V>* _runner;
};


@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .mixed_moe_gemm import *
from .moe_gemm import *


@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from ....inference_utils import ActivationType, DtypeEnum
from deepspeed.ops.op_builder import InferenceCutlassBuilder
from typing import Optional
class MixedMoEGEMM(DSKernelBase):
"""
CUTLASS implementation of a mixed-precision MoE GEMM (FP16/BF16 activations with 4-/8-bit quantized expert weights).
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None:
if not isinstance(fp_dtype, DtypeEnum):
fp_dtype = DtypeEnum(fp_dtype)
if fp_dtype not in MixedMoEGEMM.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, MixedMoEGEMM.supported_dtypes))
if act_fn not in MixedMoEGEMM.supported_act_fns:
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
act_fn, MixedMoEGEMM.supported_act_fns))
if num_bits != 4 and num_bits != 8:
raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits))
inf_module = InferenceCutlassBuilder().load()
self.num_bits = num_bits
self.kernel = inf_module.mixed_moe_gemm
self.act_fn = act_fn
def __call__(self,
ordered_output: torch.Tensor,
ordered_input: torch.Tensor,
weights: torch.Tensor,
scales: torch.Tensor,
total_rows_before_expert: torch.Tensor,
biases: Optional[torch.Tensor] = None) -> None:
"""
Performs a mixed-precision MoE GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
Arguments:
ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous.
scales (torch.Tensor): The scales of shape [n_experts, out_neurons]. These scales must be contiguous.
total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts].
biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous.
Returns:
ordered_output
"""
self.kernel(ordered_output, ordered_input, weights, scales, biases, total_rows_before_expert, self.num_bits,
self.act_fn)
return ordered_output
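
A usage sketch for the MixedMoEGEMM wrapper above, assuming a hypothetical 4-expert layer. The int8 tensor is only a placeholder for the packed 8-bit expert weights, and total_rows_before_expert is assumed to be the int64 running token count per expert (an inclusive prefix sum over tokens already sorted by expert, following the grouped-GEMM convention).

import torch

n_experts, in_neurons, out_neurons = 4, 1024, 4096
tokens_per_expert = torch.tensor([3, 0, 5, 2])  # hypothetical routing result
total_rows_before_expert = torch.cumsum(tokens_per_expert, dim=0).to(torch.int64).cuda()
n_tokens = int(tokens_per_expert.sum())

moe_gemm = MixedMoEGEMM(DtypeEnum.fp16, ActivationType.SILU, num_bits=8)

ordered_input = torch.randn(n_tokens, in_neurons, dtype=torch.float16, device="cuda")
weights = torch.empty(n_experts, in_neurons, out_neurons, dtype=torch.int8, device="cuda")
scales = torch.ones(n_experts, out_neurons, dtype=torch.float16, device="cuda")
ordered_output = torch.empty(n_tokens, out_neurons, dtype=torch.float16, device="cuda")

moe_gemm(ordered_output, ordered_input, weights, scales, total_rows_before_expert)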


@ -0,0 +1,175 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h>
#include "moe_gemm.h"
#include "moe_gemm_api.h"
#include "weight_variant.h"
// Switch helpers inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#define HIDDEN_DTYPE_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using ActivationDtype = __half; \
constexpr WeightVariant WVariant = WeightVariant::kFP16; \
return __VA_ARGS__(); \
} else { \
using ActivationDtype = __nv_bfloat16; \
constexpr WeightVariant WVariant = WeightVariant::kBF16; \
return __VA_ARGS__(); \
} \
}()
void moe_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
c10::optional<at::Tensor>& bias,
at::Tensor& total_rows_before_expert,
int activation_raw)
{
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
"Output and hidden states must have the same dtype");
TORCH_CHECK(output.dtype() == weight.dtype(), "Output and weight must have the same dtype");
int64_t total_rows = hidden_states.size(0);
int64_t gemm_k = hidden_states.size(1);
int64_t gemm_n = weight.size(2);
int num_experts = weight.size(0);
TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch");
TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch");
TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch");
TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch");
HIDDEN_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
fastertransformer::MoeGemmRunner<ActivationDtype, WVariant> runner =
*MoeGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
ActivationType activation_type = (ActivationType)activation_raw;
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(),
(char*)weight.data_ptr(),
nullptr,
(ActivationDtype*)output.data_ptr(),
(int64_t*)total_rows_before_expert.data_ptr(),
total_rows,
gemm_n,
gemm_k,
num_experts,
at::cuda::getCurrentCUDAStream());
return;
} else {
ActivationDtype* bias_ptr = nullptr;
if (bias.has_value()) {
bias_ptr = (ActivationDtype*)bias.value().data_ptr();
TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch");
TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch");
}
runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
(char*)weight.data_ptr(),
nullptr,
bias_ptr,
(ActivationDtype*)output.data_ptr(),
(int64_t*)total_rows_before_expert.data_ptr(),
total_rows,
gemm_n,
gemm_k,
num_experts,
activation_type,
at::cuda::getCurrentCUDAStream());
return;
}
});
}
#define ACT_DTYPE_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using ActivationDtype = __half; \
return __VA_ARGS__(); \
} else { \
using ActivationDtype = __nv_bfloat16; \
return __VA_ARGS__(); \
} \
}()
#define WEIGHT_VARIANT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr WeightVariant WVariant = WeightVariant::kFP8; \
return __VA_ARGS__(); \
} else { \
constexpr WeightVariant WVariant = WeightVariant::kFP4; \
return __VA_ARGS__(); \
} \
}()
void mixed_moe_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
at::Tensor& scales,
c10::optional<at::Tensor>& bias,
at::Tensor& total_rows_before_expert,
int num_bits,
int activation_raw)
{
TORCH_CHECK(output.dtype() == hidden_states.dtype(),
"Output and hidden states must have the same dtype");
int64_t total_rows = hidden_states.size(0);
int64_t gemm_k = hidden_states.size(1);
int64_t gemm_n = weight.size(2);
int num_experts = weight.size(0);
TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch");
TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch");
TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch");
TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch");
ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] {
WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] {
fastertransformer::MoeGemmRunner<ActivationDtype, WVariant> runner =
*MoeGemmContext<ActivationDtype, WVariant>::Instance().GeMM_Runner();
ActivationType activation_type = (ActivationType)activation_raw;
if (!bias.has_value() && activation_type == ActivationType::IDENTITY) {
runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(),
(char*)weight.data_ptr(),
(ActivationDtype*)scales.data_ptr(),
(ActivationDtype*)output.data_ptr(),
(int64_t*)total_rows_before_expert.data_ptr(),
total_rows,
gemm_n,
gemm_k,
num_experts,
at::cuda::getCurrentCUDAStream());
return;
} else {
ActivationDtype* bias_ptr = nullptr;
if (bias.has_value()) {
bias_ptr = (ActivationDtype*)bias.value().data_ptr();
TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch");
TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch");
}
runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(),
(char*)weight.data_ptr(),
(ActivationDtype*)scales.data_ptr(),
bias_ptr,
(ActivationDtype*)output.data_ptr(),
(int64_t*)total_rows_before_expert.data_ptr(),
total_rows,
gemm_n,
gemm_k,
num_experts,
activation_type,
at::cuda::getCurrentCUDAStream());
return;
}
});
});
}
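
The checks and casts above pin down the host-side contract for the 16-bit moe_gemm path; a small sketch with hypothetical sizes (note in particular that total_rows_before_expert is read through an int64_t*, so it must be a CUDA int64 tensor of length n_experts):

import torch

n_experts, total_rows, gemm_k, gemm_n = 8, 64, 4096, 14336  # hypothetical sizes

hidden_states = torch.randn(total_rows, gemm_k, dtype=torch.float16, device="cuda")
weight = torch.randn(n_experts, gemm_k, gemm_n, dtype=torch.float16, device="cuda")
bias = torch.zeros(n_experts, gemm_n, dtype=torch.float16, device="cuda")
output = torch.empty(total_rows, gemm_n, dtype=torch.float16, device="cuda")
total_rows_before_expert = torch.full((n_experts,), total_rows // n_experts, device="cuda").cumsum(0)

assert output.dtype == hidden_states.dtype == weight.dtype
assert weight.shape == (n_experts, gemm_k, gemm_n) and bias.shape == (n_experts, gemm_n)
assert total_rows_before_expert.dtype == torch.int64 and total_rows_before_expert.numel() == n_experts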


@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <torch/extension.h>
void moe_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
c10::optional<at::Tensor>& bias,
at::Tensor& total_rows_before_expert,
int activation_raw);
void mixed_moe_gemm(at::Tensor& output,
at::Tensor& hidden_states,
at::Tensor& weight,
at::Tensor& scales,
c10::optional<at::Tensor>& bias,
at::Tensor& total_rows_before_expert,
int num_bits,
int activation_raw);


@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from ....inference_utils import ActivationType, DtypeEnum
from deepspeed.ops.op_builder import InferenceCutlassBuilder
from typing import Optional
class MoEGEMM(DSKernelBase):
"""
CUTLASS implementation of MoE GEMM.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY]
def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType) -> None:
if not isinstance(fp_dtype, DtypeEnum):
fp_dtype = DtypeEnum(fp_dtype)
if fp_dtype not in MoEGEMM.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, MoEGEMM.supported_dtypes))
if act_fn not in MoEGEMM.supported_act_fns:
raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
act_fn, MoEGEMM.supported_act_fns))
inf_module = InferenceCutlassBuilder().load()
self.kernel = inf_module.moe_gemm
self.act_fn = act_fn
def __call__(self,
ordered_output: torch.Tensor,
ordered_input: torch.Tensor,
weights: torch.Tensor,
total_rows_before_expert: torch.Tensor,
biases: Optional[torch.Tensor] = None) -> None:
"""
Performs a MoE GEMM. Note that the stride between token inputs must be uniform (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2).
Arguments:
ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons].
ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons].
weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous.
total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts].
biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous.
Returns:
ordered_output
"""
self.kernel(ordered_output, ordered_input, weights, biases, total_rows_before_expert, self.act_fn)
return ordered_output
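
A small worked example of the expert ordering and the total_rows_before_expert convention the wrapper above relies on (an inclusive prefix sum over expert-sorted tokens is assumed here; the exact convention is whatever the underlying CUTLASS grouped GEMM expects):

import torch

# Say a top-1 router assigned 7 tokens to 3 experts as follows (hypothetical):
expert_assignment = torch.tensor([2, 0, 0, 1, 2, 2, 0])
order = torch.sort(expert_assignment, stable=True).indices            # row order for ordered_input
tokens_per_expert = torch.bincount(expert_assignment, minlength=3)    # -> [3, 1, 3]
total_rows_before_expert = torch.cumsum(tokens_per_expert, dim=0)     # -> [3, 4, 7]
# ordered_input = hidden_states[order]: rows 0..2 feed expert 0, row 3 expert 1, rows 4..6 expert 2.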


@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "activation_type.h"
#include "weight_variant.h"
namespace fastertransformer {
template <typename T, /*The type used for activations/scales/compute*/
WeightVariant V /* The type for the MoE weights */>
class MoeGemmRunner {
public:
MoeGemmRunner();
void moe_gemm_bias_act(const T* A,
const char* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
ActivationType activation_type,
cudaStream_t stream);
void moe_gemm(const T* A,
const char* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
cudaStream_t stream);
private:
int sm_;
int multi_processor_count_;
};
} // namespace fastertransformer
template <typename T, WeightVariant V>
class MoeGemmContext {
public:
MoeGemmContext() { _runner = new fastertransformer::MoeGemmRunner<T, V>(); }
virtual ~MoeGemmContext() { delete _runner; }
static MoeGemmContext& Instance()
{
static MoeGemmContext _ctx;
return _ctx;
}
fastertransformer::MoeGemmRunner<T, V>* GeMM_Runner() const { return _runner; }
fastertransformer::MoeGemmRunner<T, V>* _runner;
};


@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
// Data structure that allows us to abstract internal CUTLASS datatypes/mappings
// to the DeepSpeed-Kernels repo.
#pragma once
enum WeightVariant { kFP16, kBF16, kFP8, kFP4 };


@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from abc import ABC, abstractmethod
class DSKernelBase(ABC):
@abstractmethod
def __init__(self, *args, **kwargs):
"""
If necessary, trigger compilation and warmup here.
Autotuning of the kernel should happen at this stage to
eliminate any potential hangs that might occur mid-deployment.
Validate that the desired run configuration is compatible.
It is not necessary to call super() on this method.
"""
raise NotImplementedError()
@abstractmethod
def __call__(self, *args, **kwargs):
"""
However the kernel needs to be called, it can be called here. Auto-tuning
should never be performed here.
All inputs/outputs should be passed as arguments to this function. No allocations
should be performed here.
"""
raise NotImplementedError()
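
A minimal sketch of the contract this base class defines, using a hypothetical element-wise scaling "kernel"; real subclasses (see the wrappers elsewhere in this diff) load a compiled op builder in __init__ instead.

import torch

class ScaleKernel(DSKernelBase):

    def __init__(self, scale: float) -> None:
        # Validation, compilation and warmup belong here, never in __call__.
        self.scale = float(scale)

    def __call__(self, out: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
        # No allocations: the caller provides both the input and the output buffer.
        torch.mul(inp, self.scale, out=out)
        return out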


@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .atom_builder import *
from .blocked_flash import *
from .embed import *
from .linear_blocked_kv_rotary import *
from .logits_gather import *
from .moe_gather import *
from .moe_scatter import *
from .top_1_gating import *


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .atom_builder import *


@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "atom_builder.h"
#include "attention_atom.h"
#include "ragged_dtypes.h"
int32_t build_atoms(torch::Tensor& atoms_ten,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& kv_ptrs,
const int32_t q_block_size,
const int32_t kv_block_size)
{
const RaggedBatchDescriptor* batch_desc =
reinterpret_cast<const RaggedBatchDescriptor*>(batch_metadata.data_ptr());
const InflightSeqDescriptor* seq_desc =
reinterpret_cast<const InflightSeqDescriptor*>(seq_metadata.data_ptr());
int32_t** kv_ptr_list = reinterpret_cast<int32_t**>(kv_ptrs.data_ptr());
AttentionAtom* atoms = reinterpret_cast<AttentionAtom*>(atoms_ten.data_ptr());
int32_t n_atoms = 0;
for (int i = 0; i < batch_desc->n_sequences; i++) {
const int seq_atoms = (seq_desc[i].n_tokens + q_block_size - 1) / q_block_size;
int32_t cur_start_idx = seq_desc[i].start_idx;
int32_t global_start_idx = seq_desc[i].seen_tokens;
int32_t remaining_toks = seq_desc[i].n_tokens;
for (int j = 0; j < seq_atoms; j++) {
atoms[n_atoms].block_idx_list = kv_ptr_list[i];
atoms[n_atoms].q_start_idx = cur_start_idx;
atoms[n_atoms].q_len = std::min(remaining_toks, q_block_size);
atoms[n_atoms].global_q_idx = global_start_idx;
const int32_t end_toks = global_start_idx + atoms[n_atoms].q_len;
// TODO(cmikeh2): This logic needs to be changed for sparse implementations
atoms[n_atoms].kv_blocks = (end_toks + kv_block_size - 1) / kv_block_size;
atoms[n_atoms].total_extent = end_toks;
cur_start_idx += atoms[n_atoms].q_len;
global_start_idx += atoms[n_atoms].q_len;
remaining_toks -= atoms[n_atoms].q_len;
n_atoms++;
}
}
return n_atoms;
}
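
The partitioning above is simple enough to mirror in plain Python; a reference sketch, where the dict-based sequence descriptors are hypothetical stand-ins for the InflightSeqDescriptor entries consumed by build_atoms:

def build_atoms_reference(seqs, q_block_size, kv_block_size):
    # seqs: list of dicts with "start_idx", "n_tokens" and "seen_tokens".
    atoms = []
    for s in seqs:
        cur_start, global_start, remaining = s["start_idx"], s["seen_tokens"], s["n_tokens"]
        while remaining > 0:
            q_len = min(remaining, q_block_size)
            end_toks = global_start + q_len
            atoms.append({
                "q_start_idx": cur_start,
                "q_len": q_len,
                "global_q_idx": global_start,
                "kv_blocks": (end_toks + kv_block_size - 1) // kv_block_size,
                "total_extent": end_toks,
            })
            cur_start += q_len
            global_start += q_len
            remaining -= q_len
    return atoms

# e.g. one sequence continuing a prompt: 300 new tokens after 768 already seen, with
# q_block_size=128 and kv_block_size=64, yields three atoms with q_len 128/128/44 and
# kv_blocks 14/16/17.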


@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <torch/extension.h>
/*
Construct the attention atoms given the ragged metadata for the current batch.
This could largely be done at the Python level, but since we pack the KV ptr
alongside the int32_t metadata, it gets very ugly to handle the mixed-width
data structures (since we're packing them in a single tensor).
*/
int32_t build_atoms(torch::Tensor& atoms_ten,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& kv_ptrs,
const int32_t q_block_size,
const int32_t kv_block_size);


@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Tuple
import torch
from ... import DSKernelBase
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....ragged import RaggedBatchWrapper
class AtomBuilder(DSKernelBase):
"""
C++ implementation to populate the attention atoms for the blocked attention
kernel.
"""
def __init__(self) -> None:
"""
Triggers compilation of the C++ implementation.
"""
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.build_atoms
def __call__(self, atoms: torch.Tensor, ragged_batch: RaggedBatchWrapper, q_block_size: int,
kv_block_size: int) -> Tuple[torch.Tensor, int]:
"""
Populates the attention atoms for the blocked attention kernel.
Args:
atoms (torch.Tensor): Pre-allocated CPU int32 tensor of shape [max_atoms, 8]
ragged_batch (RaggedBatchWrapper): Wrapper for the ragged batch.
q_block_size (int): The block size for the queries (as determined by the
attention implementation)
kv_block_size (int): The block size for the keys/values (as determined by the
attention implementation)
Returns:
The populated atoms tensor and the number of valid atoms written to it.
"""
if atoms.device != torch.device("cpu"):
raise RuntimeError("AtomBuilder must be called on CPU tensors")
n_atoms = self.kernel(atoms, ragged_batch.batch_metadata_buffer(on_device=False),
ragged_batch.inflight_seq_descriptors(on_device=False),
ragged_batch.kv_ptrs(on_device=False), q_block_size, kv_block_size)
return atoms, n_atoms


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .blocked_flash import *


@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <cstdint>
#include "cuda.h"
struct AttentionAtom {
/*
The attention atom describes the workload of a particular query. The attention
kernel will execute each ``AttentionAtom`` for each head of the model.
*/
// Pointer to a list of KV block indices.
int32_t* block_idx_list;
// Index of first token in the ragged batch associated with this atom.
int32_t q_start_idx;
// Number of tokens in the ragged batch associated with this atom.
int32_t q_len;
// Number of key/value blocks associated with this atom. All but the last are
// assumed to be fully dense.
int32_t kv_blocks;
// Total number of key/value tokens covered by this atom (the extent of the KV history its queries attend over).
int32_t total_extent;
// Global index of the first token in the atom. For example, in a prompt continuation
// in which we have already processed 768 tokens, this would be 768.
int32_t global_q_idx;
// Unused
int32_t unused;
};
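
build_atoms writes these structs directly into the int32 atoms tensor of shape [max_atoms, 8] (see the AtomBuilder wrapper earlier in this diff), so each row packs the fields in declaration order. A hypothetical host-side decoder for debugging, assuming the usual 8-byte little-endian device pointer occupies the first two slots:

from collections import namedtuple

AtomView = namedtuple("AtomView", ["block_idx_list_ptr", "q_start_idx", "q_len",
                                   "kv_blocks", "total_extent", "global_q_idx"])

def unpack_atom_row(row):
    # row: the eight int32 values of one atom; row[7] is the unused padding slot.
    ptr = (int(row[1]) & 0xFFFFFFFF) << 32 | (int(row[0]) & 0xFFFFFFFF)
    return AtomView(ptr, int(row[2]), int(row[3]), int(row[4]), int(row[5]), int(row[6]))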


@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include "blocked_flash.h"
#include "flash.h"
#define CHECK_SHAPE(x, ...) \
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
#x " must have shape (" #__VA_ARGS__ ")")
void flash_attn_by_atoms(at::Tensor& out,
at::Tensor& q,
at::Tensor& k,
at::Tensor& v,
at::Tensor& attention_atoms,
const float softmax_scale,
const bool is_causal)
{
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const int total_q = q.size(0);
const int head_size = k.size(-1);
const int num_heads_kv = k.size(-2);
const int num_heads_q = q.size(-1) / head_size;
TORCH_CHECK(head_size <= 256, "head_size must be <= 256");
TORCH_CHECK(head_size % 8 == 0, "head_size must be divisible by 8");
TORCH_CHECK(num_heads_q % num_heads_kv == 0, "num_heads_q must be divisible by num_heads_kv");
Flash_fwd_params params;
params.is_bf16 = q.dtype() == torch::kBFloat16;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
params.o_ptr = out.data_ptr();
params.atoms = reinterpret_cast<AttentionAtom*>(attention_atoms.data_ptr());
// All stride are in elements, not bytes.
params.q_row_stride = q.stride(0);
params.k_row_stride = k.stride(1);
params.v_row_stride = v.stride(1);
params.o_row_stride = out.stride(0);
// Assume heads are contiguous.
params.q_head_stride = head_size;
params.k_head_stride = head_size;
params.v_head_stride = head_size;
params.o_head_stride = head_size;
// Head params
params.h = num_heads_q;
params.h_k = num_heads_kv;
params.h_h_k_ratio = num_heads_q / num_heads_kv;
params.d = head_size;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
params.d_rounded = round_multiple(head_size, 32);
params.num_atoms = attention_atoms.size(0);
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.is_causal = is_causal;
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
}


@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <torch/extension.h>
void flash_attn_by_atoms(at::Tensor& out,
at::Tensor& q,
at::Tensor& k,
at::Tensor& v,
at::Tensor& attention_atoms,
const float softmax_scale,
const bool is_causal);


@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.accelerator import get_accelerator
from ....inference_utils import DtypeEnum
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ... import DSKernelBase
def get_q_block_size(head_size: int) -> int:
"""
Returns the query block size required by the kernel given a head size.
"""
cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda
if cc_major < 8:
raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0")
if head_size <= 64:
return 128
elif head_size <= 160:
if cc_minor != 0:
return 64
else:
return 128
elif head_size == 192:
return 128
elif head_size == 224:
if cc_minor != 0:
return 64
else:
return 128
else:
if cc_major == 8 and cc_minor == 0:
return 128
else:
return 64
def get_kv_block_size(head_size: int) -> int:
"""
Returns the preferred block granularity for the blocked KV-cache implementation.
"""
cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda
if cc_major < 8:
raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0")
if head_size <= 64:
return 128
elif head_size != 160 or cc_minor != 0:
return 64
else:
return 32
class BlockedFlashAttn(DSKernelBase):
"""
Modified implementation of FlashAttention-2 tuned for inference over a blocked KV-cache and a
wider range of input sequence lengths.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
def __init__(self, head_size: int, dtype: DtypeEnum) -> None:
"""
Triggers any compilation of the kernels.
"""
if not isinstance(dtype, DtypeEnum):
dtype = DtypeEnum(dtype)
if dtype not in BlockedFlashAttn.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported data types are {}".format(
dtype, BlockedFlashAttn.supported_dtypes))
# Temporarily relaxed to 16 for testing; should revert to 32
if head_size % 16 != 0:
raise ValueError("Head size must be divisible by 16 (configured with {})".format(head_size))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.flash_attn_by_atoms
def __call__(self, out: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, atoms: torch.Tensor,
softmax_scale: float) -> torch.Tensor:
"""
Flash attention implementation atop a blocked KV-cache. Atoms should be pre-populated.
See attention_atom.h for further details on the structure of the information.
Arguments:
out (torch.Tensor): Output tensor of shape [tokens, hidden_size]
q (torch.Tensor): Query tensor of shape [tokens, hidden_size]
k (torch.Tensor): Key cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension.
v (torch.Tensor): Value cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension.
atoms (torch.Tensor): Atom information tensor of shape [num_atoms, 8] and type int32.
Not all data is readable in this format. See attention_atom.h for further details.
softmax_scale (float): Softmax scale factor.
Returns:
out (torch.Tensor): Output tensor of shape [tokens, hidden_size]
"""
self.kernel(out, q, k, v, atoms, softmax_scale, True)
return out
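
Putting the pieces together, a sketch of the intended call flow. The population of the RaggedBatchWrapper, the max_atoms bound, the KV-cache tensors and the conventional softmax scale of head_size ** -0.5 are all assumptions made for illustration.

import torch

head_size, n_heads_q, n_heads_kv = 128, 32, 8
n_tokens, n_blocks, max_atoms = 8, 4, 64  # hypothetical sizes

q_block = get_q_block_size(head_size)
kv_block = get_kv_block_size(head_size)

atom_builder = AtomBuilder()
attn = BlockedFlashAttn(head_size, DtypeEnum.fp16)

# `ragged_batch` is assumed to be a RaggedBatchWrapper populated elsewhere for these tokens.
atoms_buf = torch.empty(max_atoms, 8, dtype=torch.int32, device="cpu")
atoms, n_atoms = atom_builder(atoms_buf, ragged_batch, q_block, kv_block)

q = torch.randn(n_tokens, n_heads_q * head_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(q)
k_cache = torch.empty(n_blocks, kv_block, n_heads_kv, head_size, dtype=torch.float16, device="cuda")
v_cache = torch.empty_like(k_cache)

attn(out, q, k_cache, v_cache, atoms[:n_atoms].cuda(), head_size ** -0.5)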


@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/******************************************************************************
Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#include "attention_atom.h"
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
// The attention metadata
AttentionAtom* __restrict__ atoms;
// Total attention atoms
int num_atoms;
// The stride between rows of O.
index_t o_row_stride;
index_t o_head_stride;
// The dimensions
int d, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream);


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .embed import RaggedEmbeddingKernel


@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "embed.h"
#include "ragged_kernel_helpers.h"
#ifdef BF16_AVAILABLE
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using float_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using float_t = __half; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kBFloat16) { \
using float_t = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dispatch type"); \
} \
}()
#else
#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kFloat32) { \
using float_t = float; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kFloat16) { \
using float_t = __half; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dispatch type"); \
} \
}()
#endif
#define DISPATCH_FOR_INT(DTYPE, ...) \
[&] { \
if (DTYPE == torch::kInt32) { \
using int_t = int32_t; \
return __VA_ARGS__(); \
} else if (DTYPE == torch::kInt64) { \
using int_t = int64_t; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported dispatch type"); \
} \
}()
/*
Embeddings kernel aware of ragged batch structure.
*/
void ragged_embed(torch::Tensor& embedded_tokens,
torch::Tensor& input_ids,
torch::Tensor& embedding_weight,
c10::optional<torch::Tensor>& position_embedding_weight,
int32_t pos_embed_offset,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs)
{
// We don't care about KV cache here, so just hardcoding 0s for block_size/num_blocks
BatchWrapperCPP batch_wrapper =
make_cpp_batch_wrapper(batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, 0, 0);
const int32_t n_tokens = input_ids.numel();
const int32_t embed_dim = embedding_weight.size(1);
const int32_t vocab_size = embedding_weight.size(0);
DISPATCH_FOR_INT(input_ids.scalar_type(), [&] {
DISPATCH_FOR_FLOAT(embedding_weight.scalar_type(), [&] {
float_t* pos_embed_ptr = nullptr;
int32_t max_position_embed_idx = 0;
if (position_embedding_weight.has_value()) {
TORCH_CHECK(
position_embedding_weight.value().options().dtype() ==
embedding_weight.options().dtype(),
"position_embedding_weight and embedding_weight must have the same dtype");
pos_embed_ptr =
reinterpret_cast<float_t*>(position_embedding_weight.value().data_ptr());
max_position_embed_idx = position_embedding_weight.value().size(0) - 1;
}
launch_ragged_embed_kernel((float_t*)embedded_tokens.data_ptr(),
(const int_t*)input_ids.data_ptr(),
(const float_t*)embedding_weight.data_ptr(),
pos_embed_ptr,
batch_wrapper,
n_tokens,
embed_dim,
vocab_size,
max_position_embed_idx,
pos_embed_offset,
at::cuda::getCurrentCUDAStream());
});
});
}


@ -0,0 +1,137 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "ds_kernel_utils.h"
#include "embed.cuh"
#include "memory_access_utils.h"
#include "ragged_dtypes.h"
namespace embed {
constexpr int granularity = 16;
constexpr int threads = 512;
} // namespace embed
template <typename TokenType, typename EmbedType>
__global__ void ragged_embed_kernel(EmbedType* embedded_tokens,
const TokenType* input_ids,
const EmbedType* embedding_weight,
const EmbedType* position_weight,
const BatchWrapperCPP batch_desc,
const int32_t embed_dim,
const int32_t vocab_size,
const int32_t max_position_embed_idx,
const int32_t position_embed_offset)
{
constexpr int T_vector = embed::granularity / sizeof(EmbedType);
const int32_t token_idx = blockIdx.y;
// It's possible our batch is padded (typically when running under CUDA graphs)
if (token_idx >= batch_desc.batch_metadata->n_tokens) return;
TokenType token_value = input_ids[token_idx];
if (token_value >= vocab_size || token_value < 0) {
// TODO(cmikeh2): This is invalid, but not sure how we want to handle it being invalid
// yet.
return;
}
const EmbedType* embedding_row = embedding_weight + token_value * embed_dim;
EmbedType* dest_row = embedded_tokens + token_idx * embed_dim;
const int channel_offset = (threadIdx.x + embed::threads * blockIdx.x) * T_vector;
if (channel_offset < embed_dim) {
EmbedType reg_buf[T_vector];
mem_access::load_global<embed::granularity>(reg_buf, embedding_row + channel_offset);
if (position_weight != nullptr) {
// Map the token to its global idx (indirect memory accesses aren't great but whatever)
const int32_t seq_idx = batch_desc.tokens_to_seq[token_idx];
const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_idx];
int32_t pos_emb_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx);
// Position embed offset is an OPT-specific feature I think?
pos_emb_idx = pos_emb_idx + position_embed_offset;
// Clamp the position index into [0, max_position_embed_idx] so the lookup below stays in bounds.
pos_emb_idx = (pos_emb_idx < 0) ? 0 : pos_emb_idx;
pos_emb_idx = (pos_emb_idx >= max_position_embed_idx) ? max_position_embed_idx
: pos_emb_idx;
const EmbedType* position_embedding_row = position_weight + pos_emb_idx * embed_dim;
EmbedType pos_buf[T_vector];
mem_access::load_global<embed::granularity>(pos_buf,
position_embedding_row + channel_offset);
#pragma unroll
for (int i = 0; i < T_vector; i++) { reg_buf[i] += pos_buf[i]; }
}
mem_access::store_global<embed::granularity>(dest_row + channel_offset, reg_buf);
}
}
template <typename TokenType, typename EmbedType>
void launch_ragged_embed_kernel(EmbedType* embedded_tokens,
const TokenType* input_ids,
const EmbedType* embedding_weight,
const EmbedType* position_weight,
const BatchWrapperCPP batch_desc,
const int32_t n_tokens,
const int32_t embed_dim,
const int32_t vocab_size,
const int32_t max_position_embed_idx,
const int32_t position_embed_offset,
cudaStream_t stream)
{
constexpr int T_vector = embed::granularity / sizeof(EmbedType);
constexpr int elems_per_block = embed::threads * T_vector;
const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block;
const dim3 grid_dim(parallel_blocks, n_tokens, 1);
const dim3 block_dim(embed::threads, 1, 1);
ragged_embed_kernel<TokenType, EmbedType>
<<<grid_dim, block_dim, 0, stream>>>(embedded_tokens,
input_ids,
embedding_weight,
position_weight,
batch_desc,
embed_dim,
vocab_size,
max_position_embed_idx,
position_embed_offset);
}
#define INSTANTIATE_EMBED_FOR_TYPES(TOKEN_TYPE, EMBED_TYPE) \
template void launch_ragged_embed_kernel<TOKEN_TYPE, EMBED_TYPE>( \
EMBED_TYPE * embedded_tokens, \
const TOKEN_TYPE* input_ids, \
const EMBED_TYPE* embedding_weight, \
const EMBED_TYPE* position_weight, \
const BatchWrapperCPP batch_descriptor, \
const int32_t n_tokens, \
const int32_t embed_dim, \
const int32_t vocab_size, \
const int32_t max_position_embed_idx, \
const int32_t position_embed_offset, \
cudaStream_t stream);
INSTANTIATE_EMBED_FOR_TYPES(int32_t, float)
INSTANTIATE_EMBED_FOR_TYPES(int64_t, float)
INSTANTIATE_EMBED_FOR_TYPES(int32_t, __half)
INSTANTIATE_EMBED_FOR_TYPES(int64_t, __half)
#ifdef BF16_AVAILABLE
INSTANTIATE_EMBED_FOR_TYPES(int32_t, __nv_bfloat16)
INSTANTIATE_EMBED_FOR_TYPES(int64_t, __nv_bfloat16)
#endif
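
The position index math in the kernel above reduces to a few lines; a pure-Python reference:

def position_index(token_idx, seq_start_idx, seen_tokens, offset, max_position_embed_idx):
    # Local offset of the token within the current chunk, plus the tokens this sequence
    # has already processed, plus the (OPT-style) position embedding offset, clamped to
    # the valid range of the positional table.
    idx = seen_tokens + (token_idx - seq_start_idx) + offset
    return min(max(idx, 0), max_position_embed_idx)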


@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "ds_kernel_utils.h"
#include "ragged_dtypes.h"
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
template <typename TokenType, typename EmbedType>
void launch_ragged_embed_kernel(EmbedType* embedded_tokens,
const TokenType* input_ids,
const EmbedType* embedding_weight,
const EmbedType* position_weight,
const BatchWrapperCPP batch_desc,
const int32_t n_tokens,
const int32_t embed_dim,
const int32_t vocab_size,
const int32_t max_position_embed_idx,
const int32_t position_embed_offset,
cudaStream_t stream);


@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "embed.cuh"
/*
Embeddings kernel aware of ragged batch structure.
*/
void ragged_embed(torch::Tensor& embedded_tokens,
torch::Tensor& input_ids,
torch::Tensor& embedding_weight,
c10::optional<torch::Tensor>& position_weight,
int32_t position_embed_offset,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs);


@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Optional
import torch
from ... import DSKernelBase
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....inference_utils import elem_size
from ....ragged import RaggedBatchWrapper
class RaggedEmbeddingKernel(DSKernelBase):
"""
Ragged-aware CUDA kernel implementation for an embedding lookup. This will only look up
the necessary tokens for a padded batch (i.e. if we are running under CUDA graphs with a
slightly larger batch size than the actual number of tokens).
"""
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
supported_token_dtypes = [torch.int32, torch.int64]
def __init__(self, embed_dtype: torch.dtype, token_dtype: torch.dtype, embed_dim: int) -> None:
"""
Args:
embed_dtype (torch.dtype): Data type of the embedding table and the output.
Supported values are torch.float16, torch.bfloat16, and torch.float32.
token_dtype (torch.dtype): Data type of the token ids. Supported values are
torch.int32 and torch.int64.
embed_dim (int): Embedding dimension. Must be aligned to 16 bytes.
"""
if embed_dtype not in RaggedEmbeddingKernel.supported_dtypes:
raise ValueError("Unsupported embedding data type: {}, supported_dtypes are {}".format(
embed_dtype, RaggedEmbeddingKernel.supported_dtypes))
if token_dtype not in RaggedEmbeddingKernel.supported_token_dtypes:
raise ValueError("Unsupported token data type: {}, supported_dtypes are {}".format(
token_dtype, RaggedEmbeddingKernel.supported_token_dtypes))
if elem_size(embed_dtype) * embed_dim % 16 != 0:
raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(embed_dim))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.ragged_embed
def __call__(self,
embedded_tokens: torch.Tensor,
ragged_wrapper: RaggedBatchWrapper,
embedding_weight: torch.Tensor,
position_embed_weight: Optional[torch.Tensor] = None,
position_embed_offset: int = 0) -> torch.Tensor:
"""
Ragged aware embedding lookup.
Args:
embedded_tokens (torch.Tensor): Output tensor of shape [num_tokens, embed_dim]
ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch.
embedding_weight (torch.Tensor): Embedding table of shape [vocab_size, embed_dim]
position_embed_weight (Optional[torch.Tensor]): Positional embedding table of shape
[max_positions, embed_dim]; added to the token embeddings when provided.
position_embed_offset (int): Offset added to the position index before the positional lookup.
"""
self.kernel(embedded_tokens, ragged_wrapper.input_ids(),
embedding_weight, position_embed_weight, position_embed_offset,
ragged_wrapper.batch_metadata_buffer(), ragged_wrapper.inflight_seq_descriptors(),
ragged_wrapper.tokens_to_seq(), ragged_wrapper.kv_ptrs())
return embedded_tokens
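
A minimal usage sketch of the wrapper above; the sizes are hypothetical and `ragged_batch` is assumed to be a RaggedBatchWrapper whose input_ids are int32 and already populated.

import torch

vocab_size, embed_dim, n_tokens = 32000, 4096, 16  # hypothetical sizes

embed = RaggedEmbeddingKernel(torch.float16, torch.int32, embed_dim)

embedding_weight = torch.randn(vocab_size, embed_dim, dtype=torch.float16, device="cuda")
embedded_tokens = torch.empty(n_tokens, embed_dim, dtype=torch.float16, device="cuda")

embed(embedded_tokens, ragged_batch, embedding_weight)  # positional table is optional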


@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .blocked_kv_rotary import *
from .blocked_trained_kv_rotary import *
from .linear_blocked_kv_copy import *


@ -0,0 +1,188 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "blocked_kv_rotary.h"
#include "ragged_kernel_helpers.h"
#define DISPATCH_KV_ROTARY(T_TYPE, C_TYPE) \
if (q.options().dtype() == torch::T_TYPE) { \
launch_kv_rotary_kernel<C_TYPE>((C_TYPE*)kv_cache.data_ptr(), \
(C_TYPE*)q.data_ptr(), \
(C_TYPE*)k.data_ptr(), \
(C_TYPE*)v.data_ptr(), \
(C_TYPE*)inv_freq_ptr, \
batch_wrapper, \
qkv_stride, \
kv_cache_stride, \
v_offset, \
inv_freq_stride, \
q_ratio, \
head_size, \
n_tokens, \
n_q_heads, \
at::cuda::getCurrentCUDAStream()); \
}
/*
Rotary position embeddings + copy into KV cache. This implementation assumes
that the inverse frequencies should be read from global memory rather than
synthesized in the kernel.
Arguments:
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
q: [n_tokens, n_q_heads * head_size]
k: [n_tokens, n_kv_heads * head_size]
v: [n_tokens, n_kv_heads * head_size]
inv_freq: [max_seq_len, head_size // 2]
*/
void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& inv_freq,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs)
{
const int32_t n_tokens = q.size(0);
TORCH_CHECK(n_tokens == k.size(0));
TORCH_CHECK(n_tokens == v.size(0));
// Dimensions
const int32_t block_size = kv_cache.size(1);
const int32_t n_kv_heads = kv_cache.size(3);
const int32_t head_size = kv_cache.size(4);
// Strides
const int32_t qkv_stride = q.stride(0); // Per token
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
const int32_t inv_freq_stride = inv_freq.stride(0); // Per token idx
const int n_q_heads = q.size(1) / head_size;
const int q_ratio = n_q_heads / n_kv_heads;
void* inv_freq_ptr = (void*)inv_freq.data_ptr();
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
DISPATCH_KV_ROTARY(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16);
#endif
}
/*
Rotary position embeddings + copy into KV cache. This implementation assumes
that the inverse frequencies should be synthesized in the kernel.
Arguments:
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
q: [n_tokens, n_q_heads * head_size]
k: [n_tokens, n_kv_heads * head_size]
v: [n_tokens, n_kv_heads * head_size]
*/
void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs)
{
const int32_t n_tokens = q.size(0);
TORCH_CHECK(n_tokens == k.size(0));
TORCH_CHECK(n_tokens == v.size(0));
// Dimensions
const int32_t block_size = kv_cache.size(1);
const int32_t n_kv_heads = kv_cache.size(3);
const int32_t head_size = kv_cache.size(4);
// Strides
const int32_t qkv_stride = q.stride(0); // Per token
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
const int32_t inv_freq_stride = 0; // Per token idx
const int n_q_heads = q.size(1) / head_size;
const int q_ratio = n_q_heads / n_kv_heads;
void* inv_freq_ptr = nullptr;
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
DISPATCH_KV_ROTARY(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16);
#endif
}
#define DISPATCH_KV_COPY(T_TYPE, C_TYPE) \
if (q.options().dtype() == torch::T_TYPE) { \
launch_kv_copy_kernel<C_TYPE>((C_TYPE*)kv_cache.data_ptr(), \
(C_TYPE*)q.data_ptr(), \
(C_TYPE*)k.data_ptr(), \
(C_TYPE*)v.data_ptr(), \
batch_wrapper, \
qkv_stride, \
kv_cache_stride, \
v_offset, \
q_ratio, \
head_size, \
n_tokens, \
n_q_heads, \
at::cuda::getCurrentCUDAStream()); \
}
/*
Copy into linear KV cache.
*/
void linear_kv_copy(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs)
{
const int32_t n_tokens = q.size(0);
TORCH_CHECK(n_tokens == k.size(0));
TORCH_CHECK(n_tokens == v.size(0));
// Dimensions
const int32_t block_size = kv_cache.size(1);
const int32_t n_kv_heads = kv_cache.size(3);
const int32_t head_size = kv_cache.size(4);
// Strides
const int32_t qkv_stride = q.stride(0); // Per token
TORCH_CHECK(qkv_stride == k.stride(0));
TORCH_CHECK(qkv_stride == v.stride(0));
const int32_t kv_cache_stride = kv_cache.stride(1); // Per token
const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache
const int n_q_heads = q.size(1) / head_size;
TORCH_CHECK(n_q_heads % n_kv_heads == 0);
const int q_ratio = n_q_heads / n_kv_heads;
BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper(
batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0));
DISPATCH_KV_COPY(kHalf, __half);
#ifdef BF16_AVAILABLE
DISPATCH_KV_COPY(kBFloat16, __nv_bfloat16);
#endif
}
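
For clarity, a pure-Python reference of the rotate-half rotary transform the kernels launched above apply per head: frequencies are synthesized as 1 / 10000^(2i / head_size) when no inv_freq table is supplied, matching the kernel's fallback path, and inv_freq (when given) plays the role of the trained table read from global memory.

import math

def rotary_reference(x, global_token_idx, inv_freq=None):
    # x: list of floats for one head (len(x) == head_size). Returns the rotated head.
    head_size = len(x)
    half = head_size // 2
    out = [0.0] * head_size
    for i in range(head_size):
        pair = i % half
        freq = inv_freq[pair] if inv_freq is not None else 10000.0 ** (-2.0 * pair / head_size)
        theta = freq * global_token_idx
        # Rotate-half pairing: element i is combined with its partner half a head away.
        rotated = -x[i + half] if i < half else x[i - half]
        out[i] = x[i] * math.cos(theta) + rotated * math.sin(theta)
    return out

# The same transform is applied to q for every head and, for the head that owns each KV head
# (every q_ratio-th head), to k before k/v are written into the blocked KV cache.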


@ -0,0 +1,314 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "blocked_kv_rotary.cuh"
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
namespace kv_rot {
constexpr int granularity = 16;
constexpr int threads = 256;
} // namespace kv_rot
/*
Supports head size 32, 64, 128, 256
*/
template <typename T, int qRatio, int headSize, bool doRotary>
__global__ void kv_rotary_pos_kernel(T* kv_cache,
T* q,
T* k,
T* v,
const T* inv_freq,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
const int v_offset,
const int inv_freq_stride)
{
// Derived constexpr
constexpr int vector_T = kv_rot::granularity / sizeof(T);
constexpr int threads_per_head = headSize / vector_T;
constexpr int half_head_size = headSize >> 1;
constexpr int tokens_per_block = kv_rot::threads / threads_per_head;
// CG helpers
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
cg::thread_block_tile<threads_per_head> head_group =
cg::tiled_partition<threads_per_head>(warp);
// Parallelize on the head dimension for X blocks
const int head_idx = blockIdx.x;
const int block_seq_idx = threadIdx.x / threads_per_head;
const int base_neuron_idx = (threadIdx.x * vector_T) % headSize;
const int half_idx = base_neuron_idx % half_head_size;
const int half_head_lanes = threads_per_head / 2;
// Multiple tokens processed by the same threadblock
const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx;
const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens;
const bool load_inv_freq = (inv_freq != nullptr) && valid_token;
// If we have GQA, then only one of the Q heads needs to do rotary + copy
// for each of the heads in the group.
bool need_kv = head_idx % qRatio == 0;
// Make sure the following code is warp uniform
need_kv = warp.shfl(need_kv, 0);
const int kv_head_idx = head_idx / qRatio;
// Ensure we don't access invalid portions of the seq_metadata
const int32_t seq_id = (valid_token) ? batch_desc.tokens_to_seq[token_idx] : 0;
const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id];
// This will give an invalid index if valid_token is false, but should never affect memory.
const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx);
T* q_row = q + token_idx * qkv_stride + head_idx * headSize;
T q_reg[vector_T];
if (need_kv) {
// The following logic assumes a linearly blocked KV cache. This means that no sparsity has
// been introduced into cache history.
const KVCacheDescriptor kv_desc = batch_desc.kv_desc;
const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size;
const int32_t mapped_kv_block_idx =
(valid_token) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0;
const int32_t kv_block_offset = global_token_idx % kv_desc.block_size;
const int32_t kv_offset =
(mapped_kv_block_idx * kv_desc.block_size + kv_block_offset) * kv_cache_stride +
kv_head_idx * headSize;
// Load this token's q/k/v values from the fused QKV output
T* k_row = k + token_idx * qkv_stride + kv_head_idx * headSize;
T* v_row = v + token_idx * qkv_stride + kv_head_idx * headSize;
T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T];
mem_access::load_global<kv_rot::granularity>(q_reg, q_row + base_neuron_idx, valid_token);
mem_access::load_global<kv_rot::granularity>(k_reg, k_row + base_neuron_idx, valid_token);
mem_access::load_global<kv_rot::granularity>(v_reg, v_row + base_neuron_idx, valid_token);
mem_access::load_global<kv_rot::granularity>(
inv_freq_reg, inv_freq + half_idx, load_inv_freq);
if constexpr (doRotary) {
#pragma unroll
for (int i = 0; i < vector_T; i++) {
const int head_neuron_idx = base_neuron_idx + i;
float inv_freq_flt;
if (inv_freq != nullptr) {
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
} else {
inv_freq_flt =
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
// Conversion to T and back means that both branches of this if statement
// will produce the same results if using the same algo for producing the
// freqs.
T trunc_freq = conversion::to<T>(1.0 / powf(10000.0, inv_freq_flt));
inv_freq_flt = conversion::to<float>(trunc_freq) * (float)global_token_idx;
}
float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
float q_f = conversion::to<float>(q_reg[i]);
float k_f = conversion::to<float>(k_reg[i]);
float q_rot = q_f * rotary_sign;
float k_rot = k_f * rotary_sign;
const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes);
const float k_rot_temp = head_group.shfl_xor(k_rot, half_head_lanes);
q_reg[i] =
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
k_reg[i] =
conversion::to<T>(k_f * cosf(inv_freq_flt) + k_rot_temp * sinf(inv_freq_flt));
}
}
if (valid_token) {
mem_access::store_global<kv_rot::granularity>(kv_cache + kv_offset + base_neuron_idx,
k_reg);
mem_access::store_global<kv_rot::granularity>(
kv_cache + kv_offset + base_neuron_idx + v_offset, v_reg);
}
} else {
T inv_freq_reg[vector_T];
mem_access::load_global<kv_rot::granularity>(q_reg, q_row + base_neuron_idx, valid_token);
mem_access::load_global<kv_rot::granularity>(
inv_freq_reg, inv_freq + half_idx, load_inv_freq);
if constexpr (doRotary) {
#pragma unroll
for (int i = 0; i < vector_T; i++) {
const int head_neuron_idx = base_neuron_idx + i;
float inv_freq_flt;
if (inv_freq != nullptr) {
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
} else {
inv_freq_flt =
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx;
}
float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
float q_f = conversion::to<float>(q_reg[i]);
float q_rot = q_f * rotary_sign;
const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes);
q_reg[i] =
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
}
}
}
if (valid_token && doRotary) {
mem_access::store_global<kv_rot::granularity>(q_row + base_neuron_idx, q_reg);
}
}
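For reference, the shuffle-based rotation above is equivalent to the standard rotate-half (GPT-NeoX style) RoPE applied per token. A minimal PyTorch sketch of the same per-head transform is given below; the function name, the pos argument, and the default theta = 10000.0 are illustrative stand-ins for inv_freq * global_token_idx, not part of the kernel interface.

import torch

def rotate_half_rope(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    # x: [..., head_size]. Element i is paired with element i +/- head_size // 2,
    # mirroring the head_group.shfl_xor(..., half_head_lanes) exchange in the kernel,
    # with the upper half carrying the negative sign (rotary_sign above).
    head_size = x.shape[-1]
    half = head_size // 2
    inv_freq = 1.0 / theta ** (torch.arange(half, dtype=torch.float32) * 2 / head_size)
    angles = inv_freq * pos                    # inv_freq_flt * global_token_idx
    cos = torch.cos(angles).repeat(2)          # same angle for both halves of the head
    sin = torch.sin(angles).repeat(2)
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    return x * cos + rotated * sin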
#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE) \
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, true> \
<<<grid, block, 0, stream>>>(kv_cache, \
q, \
k, \
v, \
inv_freq, \
batch_desc, \
qkv_stride, \
kv_cache_stride, \
v_offset, \
inv_freq_stride);
template <typename T>
void launch_kv_rotary_kernel(T* kv_cache,
T* q,
T* k,
T* v,
T* inv_freq,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
const int v_offset,
const int inv_freq_stride,
const int q_ratio,
const int head_size,
const int n_tokens,
const int n_q_heads,
cudaStream_t stream)
{
constexpr int vector_T = kv_rot::granularity / sizeof(T);
const int threads_per_head = head_size / vector_T;
const int tokens_per_block = kv_rot::threads / threads_per_head;
const dim3 block(kv_rot::threads);
const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block;
const dim3 grid(n_q_heads, token_blocks);
DISPATCH_KV_ROTARY_IMPL(1, 64)
DISPATCH_KV_ROTARY_IMPL(1, 128)
DISPATCH_KV_ROTARY_IMPL(2, 64)
DISPATCH_KV_ROTARY_IMPL(2, 128)
DISPATCH_KV_ROTARY_IMPL(4, 64)
DISPATCH_KV_ROTARY_IMPL(4, 128)
DISPATCH_KV_ROTARY_IMPL(5, 64)
DISPATCH_KV_ROTARY_IMPL(5, 128)
DISPATCH_KV_ROTARY_IMPL(8, 64)
DISPATCH_KV_ROTARY_IMPL(8, 128)
}
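As a quick sanity check on the launch geometry, the arithmetic above can be reproduced in a few lines of Python. The constants below (granularity = 16 bytes, 256 threads per block) are assumed values for the kv_rot namespace defined earlier in this file; adjust them if the actual constants differ.

# Illustrative launch-geometry arithmetic for launch_kv_rotary_kernel (assumed constants).
KV_ROT_GRANULARITY = 16   # bytes per vectorized access
KV_ROT_THREADS = 256      # threads per block (assumption)

def kv_rotary_launch_geometry(head_size: int, n_tokens: int, n_q_heads: int, elem_bytes: int = 2):
    vector_t = KV_ROT_GRANULARITY // elem_bytes          # fp16: 8 elements per load
    threads_per_head = head_size // vector_t             # head_size = 128 -> 16 threads
    tokens_per_block = KV_ROT_THREADS // threads_per_head
    token_blocks = (n_tokens + tokens_per_block - 1) // tokens_per_block
    return {"block": (KV_ROT_THREADS,), "grid": (n_q_heads, token_blocks)}

# Example: 500 tokens, 32 query heads, head_size 128, fp16
# -> block = (256,), grid = (32, 32) under the assumptions above.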
#define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \
template void launch_kv_rotary_kernel<TYPE>(TYPE * kv_cache, \
TYPE * q, \
TYPE * k, \
TYPE * v, \
TYPE * inv_freq, \
const BatchWrapperCPP batch_desc, \
const int qkv_stride, \
const int kv_cache_stride, \
const int v_offset, \
const int inv_freq_stride, \
const int q_ratio, \
const int head_size, \
const int n_tokens, \
const int n_q_heads, \
cudaStream_t stream);
INSTANTIATE_KV_ROTARY_KERNEL(__half)
#ifdef BF16_AVAILABLE
INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16)
#endif
#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, false><<<grid, block, 0, stream>>>( \
kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0);
template <typename T>
void launch_kv_copy_kernel(T* kv_cache,
T* q,
T* k,
T* v,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
const int v_offset,
const int q_ratio,
const int head_size,
const int n_tokens,
const int n_q_heads,
cudaStream_t stream)
{
constexpr int vector_T = kv_rot::granularity / sizeof(T);
const int threads_per_head = head_size / vector_T;
const int tokens_per_block = kv_rot::threads / threads_per_head;
const dim3 block(kv_rot::threads);
const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block;
const dim3 grid(n_q_heads, token_blocks);
DISPATCH_KV_COPY_IMPL(1, 64)
DISPATCH_KV_COPY_IMPL(1, 128)
DISPATCH_KV_COPY_IMPL(2, 64)
DISPATCH_KV_COPY_IMPL(2, 128)
DISPATCH_KV_COPY_IMPL(4, 64)
DISPATCH_KV_COPY_IMPL(4, 128)
DISPATCH_KV_COPY_IMPL(5, 64)
DISPATCH_KV_COPY_IMPL(5, 128)
DISPATCH_KV_COPY_IMPL(8, 64)
DISPATCH_KV_COPY_IMPL(8, 128)
}
#define INSTANTIATE_KV_COPY_KERNEL(TYPE) \
template void launch_kv_copy_kernel<TYPE>(TYPE * kv_cache, \
TYPE * q, \
TYPE * k, \
TYPE * v, \
const BatchWrapperCPP batch_desc, \
const int qkv_stride, \
const int kv_cache_stride, \
const int v_offset, \
const int q_ratio, \
const int head_size, \
const int n_tokens, \
const int n_q_heads, \
cudaStream_t stream);
INSTANTIATE_KV_COPY_KERNEL(__half)
#ifdef BF16_AVAILABLE
INSTANTIATE_KV_COPY_KERNEL(__nv_bfloat16)
#endif


@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "ds_kernel_utils.h"
#include "ragged_dtypes.h"
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
template <typename T>
void launch_kv_rotary_kernel(T* kv_cache,
T* q,
T* k,
T* v,
T* inv_freq,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
const int v_offset,
const int inv_freq_stride,
const int q_ratio,
const int head_size,
const int n_tokens,
const int n_q_heads,
cudaStream_t stream);
template <typename T>
void launch_kv_copy_kernel(T* kv_cache,
T* q,
T* k,
T* v,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
const int v_offset,
const int q_ratio,
const int head_size,
const int n_tokens,
const int n_q_heads,
cudaStream_t stream);


@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "blocked_kv_rotary.cuh"
/*
Rotary position embeddings + copy into KV cache. This implementation assumes
that the inverse frequencies are read from global memory rather than
synthesized in the kernel.
Arguments:
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
q: [n_tokens, n_q_heads * head_size]
k: [n_tokens, n_kv_heads * head_size]
v: [n_tokens, n_kv_heads * head_size]
inv_freq: [max_seq_len, head_size // 2]
*/
void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& inv_freq,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs);
/*
Rotary position embeddings + copy into KV cache. This implementation assumes
that the inverse frequencies are synthesized in the kernel.
Arguments:
kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size]
q: [n_tokens, n_q_heads * head_size]
k: [n_tokens, n_kv_heads * head_size]
v: [n_tokens, n_kv_heads * head_size]
*/
void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs);
/*
Copy into linear KV cache.
*/
void linear_kv_copy(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
torch::Tensor& kv_ptrs);
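The tensor shapes documented in the comments above can be summarized with a short allocation sketch; every size below (block counts, head counts, sequence lengths) is an illustrative placeholder rather than a value required by these bindings.

import torch

n_blocks, block_size = 128, 64                 # hypothetical KV cache geometry
n_q_heads, n_kv_heads, head_size = 32, 8, 128
n_tokens, max_seq_len = 512, 4096
dtype, device = torch.float16, "cuda"

kv_cache = torch.empty(n_blocks, block_size, 2, n_kv_heads, head_size, dtype=dtype, device=device)
q = torch.empty(n_tokens, n_q_heads * head_size, dtype=dtype, device=device)
k = torch.empty(n_tokens, n_kv_heads * head_size, dtype=dtype, device=device)
v = torch.empty(n_tokens, n_kv_heads * head_size, dtype=dtype, device=device)
inv_freq = torch.empty(max_seq_len, head_size // 2, dtype=dtype, device=device)  # trained variant only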


@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ....inference_utils import DtypeEnum
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....ragged import RaggedBatchWrapper
from ... import DSKernelBase
class BlockedRotaryEmbeddings(DSKernelBase):
"""
CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys
before copying into a blocked KV cache.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 128]
supported_q_ratios = [1, 2, 4, 5, 8]
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
"""
        Args:
            head_size: The size of the attention head.
            n_q_heads: Number of query heads.
            n_kv_heads: Number of key/value heads (the GQA ratio is n_q_heads // n_kv_heads).
            dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
"""
q_ratio = n_q_heads // n_kv_heads
if head_size not in BlockedRotaryEmbeddings.supported_head_sizes:
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
head_size, BlockedRotaryEmbeddings.supported_head_sizes))
if q_ratio not in BlockedRotaryEmbeddings.supported_q_ratios:
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
q_ratio, BlockedRotaryEmbeddings.supported_q_ratios))
if not isinstance(dtype, DtypeEnum):
dtype = DtypeEnum(dtype)
if dtype not in BlockedRotaryEmbeddings.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
dtype, BlockedRotaryEmbeddings.supported_dtypes))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.kv_rotary_embeddings
self.head_size = head_size
self.n_q_heads = n_q_heads
self.n_kv_heads = n_kv_heads
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
"""
Perform rotary embeddings on the queries and keys before copying into a blocked KV cache.
Args:
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
ragged_batch: Wrapper for the ragged batch.
"""
q = qkv[:, :self.head_size * self.n_q_heads]
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(),
ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
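A sketch of the call pattern, assuming a kv_cache, fused qkv tensor, and RaggedBatchWrapper have already been produced by the engine (their construction is model specific and is not shown here):

# BlockedRotaryEmbeddings is defined in this module; shapes below are illustrative.
rotary = BlockedRotaryEmbeddings(head_size=128, n_q_heads=32, n_kv_heads=8, dtype=torch.float16)
# qkv: [n_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] from the fused QKV projection
rotary(kv_cache, qkv, ragged_batch)  # applies RoPE to q (in place) and to k, then writes k/v into the cache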


@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ....inference_utils import DtypeEnum
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....ragged import RaggedBatchWrapper
from ... import DSKernelBase
class BlockedTrainedRotaryEmbeddings(DSKernelBase):
"""
CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys
before copying into a blocked KV cache.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 128]
supported_q_ratios = [1, 2, 4, 5, 8]
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
"""
        Args:
            head_size: The size of the attention head.
            n_q_heads: Number of query heads.
            n_kv_heads: Number of key/value heads (the GQA ratio is n_q_heads // n_kv_heads).
            dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
"""
q_ratio = n_q_heads // n_kv_heads
if head_size not in BlockedTrainedRotaryEmbeddings.supported_head_sizes:
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
head_size, BlockedTrainedRotaryEmbeddings.supported_head_sizes))
if q_ratio not in BlockedTrainedRotaryEmbeddings.supported_q_ratios:
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
q_ratio, BlockedTrainedRotaryEmbeddings.supported_q_ratios))
if not isinstance(dtype, DtypeEnum):
dtype = DtypeEnum(dtype)
if dtype not in BlockedTrainedRotaryEmbeddings.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
dtype, BlockedTrainedRotaryEmbeddings.supported_dtypes))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.kv_trained_rotary_embeddings
self.head_size = head_size
self.n_q_heads = n_q_heads
self.n_kv_heads = n_kv_heads
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper,
inverse_freqs: torch.Tensor) -> None:
"""
Perform rotary embeddings on the queries and keys before copying into a blocked KV cache.
Args:
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
ragged_batch: Wrapper for the ragged batch.
inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, head_size // 2]
"""
q = qkv[:, :self.head_size * self.n_q_heads]
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
self.kernel(kv_cache, q, k, v, inverse_freqs, ragged_batch.batch_metadata_buffer(),
ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
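For the trained variant, the caller supplies the inverse frequencies explicitly. Below is a hedged sketch of producing the default RoPE frequencies in the [max_seq_len, head_size // 2] layout described by the kernel header; a real model may instead pass learned or scaled values, and the theta value and shapes here are assumptions.

import torch

head_size, max_seq_len, theta = 128, 4096, 10000.0
half = head_size // 2
inv_freq_row = 1.0 / theta ** (torch.arange(half, dtype=torch.float32) * 2 / head_size)
inverse_freqs = inv_freq_row.to(torch.float16).cuda().unsqueeze(0).expand(max_seq_len, -1).contiguous()

trained_rotary = BlockedTrainedRotaryEmbeddings(head_size=head_size, n_q_heads=32, n_kv_heads=8,
                                                dtype=torch.float16)
trained_rotary(kv_cache, qkv, ragged_batch, inverse_freqs)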


@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ....inference_utils import DtypeEnum
from ....ragged import RaggedBatchWrapper
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ... import DSKernelBase
class LinearBlockedKVCopy(DSKernelBase):
"""
    CUDA kernel implementation that copies the key and value activations into a blocked KV cache.
    No rotary position embedding is applied; the keys and values are copied as-is.
"""
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 128]
supported_q_ratios = [1, 2, 4, 5, 8]
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
"""
        Args:
            head_size: The size of the attention head.
            n_q_heads: Number of query heads.
            n_kv_heads: Number of key/value heads (the GQA ratio is n_q_heads // n_kv_heads).
            dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16.
"""
q_ratio = n_q_heads // n_kv_heads
if head_size not in LinearBlockedKVCopy.supported_head_sizes:
raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format(
head_size, LinearBlockedKVCopy.supported_head_sizes))
if q_ratio not in LinearBlockedKVCopy.supported_q_ratios:
raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format(
q_ratio, LinearBlockedKVCopy.supported_q_ratios))
if not isinstance(dtype, DtypeEnum):
dtype = DtypeEnum(dtype)
if dtype not in LinearBlockedKVCopy.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
dtype, LinearBlockedKVCopy.supported_dtypes))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.linear_kv_copy
self.head_size = head_size
self.n_q_heads = n_q_heads
self.n_kv_heads = n_kv_heads
def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
"""
        Copy the key and value activations into a blocked KV cache (no rotary embedding is applied).
Args:
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
ragged_batch: Wrapper for the ragged batch.
"""
q = qkv[:, :self.head_size * self.n_q_heads]
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]
self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(),
ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .logits_gather import *


@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "logits_gather.h"
#define DISPATCH_TO_LOGITS_GATHER(T_TYPE, C_TYPE) \
if (all_acts.options().dtype() == torch::T_TYPE) { \
launch_logits_gather((C_TYPE*)final_token_acts.data_ptr(), \
(const C_TYPE*)all_acts.data_ptr(), \
batch_metadata_raw, \
seq_metadata_raw, \
n_seqs, \
embed_dim, \
at::cuda::getCurrentCUDAStream()); \
}
/*
Logits gather will parse the ragged batch data structure and gather only the logits that
will be used for token sampling.
*/
void gather_for_logits(torch::Tensor& final_token_acts,
torch::Tensor& all_acts,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata)
{
const RaggedBatchDescriptor* batch_metadata_raw =
reinterpret_cast<const RaggedBatchDescriptor*>(batch_metadata.data_ptr());
const InflightSeqDescriptor* seq_metadata_raw =
reinterpret_cast<const InflightSeqDescriptor*>(seq_metadata.data_ptr());
const int n_seqs = final_token_acts.size(0);
const int embed_dim = final_token_acts.size(1);
TORCH_CHECK(all_acts.scalar_type() == final_token_acts.scalar_type(),
"all_acts and final_token_acts must have the same scalar type");
DISPATCH_TO_LOGITS_GATHER(kFloat, float)
DISPATCH_TO_LOGITS_GATHER(kHalf, half)
#ifdef BF16_AVAILABLE
DISPATCH_TO_LOGITS_GATHER(kBFloat16, __nv_bfloat16)
#endif
}


@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "ds_kernel_utils.h"
#include "logits_gather.cuh"
#include "memory_access_utils.h"
#include "ragged_dtypes.h"
namespace logits_gather {
constexpr int granularity = 16;
constexpr int threads = 512;
} // namespace logits_gather
template <typename T>
__global__ void logits_gather_kernel(T* final_token_acts,
const T* token_acts,
const RaggedBatchDescriptor* ragged_batch,
const InflightSeqDescriptor* inflight_batch,
const int32_t embed_dim)
{
constexpr int T_vector = logits_gather::granularity / sizeof(T);
const int32_t seq_id = blockIdx.y;
// It's possible we've padded the output Tensor (under CG conditions)
if (seq_id >= ragged_batch->n_sequences) return;
const InflightSeqDescriptor seq = inflight_batch[seq_id];
const int final_token_idx = seq.start_idx + seq.n_tokens - 1;
const int token_offset = final_token_idx * embed_dim;
const int thread_offset =
threadIdx.x * T_vector + blockIdx.x * logits_gather::threads * T_vector;
const int final_token_offset = seq_id * embed_dim;
T reg_buf[T_vector];
if (thread_offset < embed_dim) {
mem_access::load_global<logits_gather::granularity>(
reg_buf, token_acts + token_offset + thread_offset);
mem_access::store_global<logits_gather::granularity>(
final_token_acts + final_token_offset + thread_offset, reg_buf);
}
}
template <typename T>
void launch_logits_gather(T* final_token_acts,
const T* all_acts,
const RaggedBatchDescriptor* ragged_batch,
const InflightSeqDescriptor* inflight_batch,
const int32_t n_seqs,
const int32_t embed_dim,
cudaStream_t stream)
{
constexpr int T_vector = logits_gather::granularity / sizeof(T);
constexpr int elems_per_block = logits_gather::threads * T_vector;
const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block;
const dim3 grid(parallel_blocks, n_seqs, 1);
const dim3 block(logits_gather::threads, 1, 1);
logits_gather_kernel<T><<<grid, block, 0, stream>>>(
final_token_acts, all_acts, ragged_batch, inflight_batch, embed_dim);
}
#define INSTANTIATE_FOR_TYPE(T) \
template void launch_logits_gather<T>(T * final_token_acts, \
const T* all_acts, \
const RaggedBatchDescriptor* ragged_batch, \
const InflightSeqDescriptor* inflight_batch, \
const int32_t n_seqs, \
const int32_t embed_dim, \
cudaStream_t stream);
INSTANTIATE_FOR_TYPE(float)
INSTANTIATE_FOR_TYPE(__half)
#ifdef BF16_AVAILABLE
INSTANTIATE_FOR_TYPE(__nv_bfloat16)
#endif
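The indexing performed by logits_gather_kernel has a direct PyTorch equivalent that can serve as a correctness reference in tests; the start_idx / n_tokens tensors below stand in for the InflightSeqDescriptor fields.

import torch

def logits_gather_reference(all_acts: torch.Tensor, start_idx: torch.Tensor, n_tokens: torch.Tensor) -> torch.Tensor:
    # all_acts: [total_tokens, embed_dim]; start_idx, n_tokens: int64 tensors of shape [n_seqs].
    final_token_idx = start_idx + n_tokens - 1   # matches seq.start_idx + seq.n_tokens - 1 above
    return all_acts[final_token_idx]             # [n_seqs, embed_dim]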


@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "ds_kernel_utils.h"
#include "ragged_dtypes.h"
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
template <typename T>
void launch_logits_gather(T* final_token_acts,
const T* all_acts,
const RaggedBatchDescriptor* batch_metadata,
const InflightSeqDescriptor* seq_metadata,
const int32_t n_seqs,
const int32_t embed_dim,
cudaStream_t stream);


@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "logits_gather.cuh"
#include "ragged_dtypes.h"
/*
Logits gather will parse the ragged batch data structure and gather only the logits that
will be used for token sampling.
*/
void gather_for_logits(torch::Tensor& final_token_acts,
torch::Tensor& all_acts,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata);


@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ... import DSKernelBase
from deepspeed.ops.op_builder import RaggedOpsBuilder
from ....inference_utils import elem_size
from ....ragged import RaggedBatchWrapper
class RaggedLogitsGather(DSKernelBase):
"""
    CUDA kernel implementation for gathering the hidden states of the final token
    of each sequence. This is used to reduce the cost of performing the unembedding.
"""
supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def __init__(self, model_dim: int, fp_dtype: torch.dtype):
"""
        Parameters:
            model_dim (int): Hidden dimension of the model (must be 16-byte aligned for the given dtype).
            fp_dtype (torch.dtype): Data type for the input/output. Supported values
                are torch.float16, torch.bfloat16, and torch.float32.
"""
if fp_dtype not in RaggedLogitsGather.supported_dtypes:
raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
fp_dtype, RaggedLogitsGather.supported_dtypes))
if elem_size(fp_dtype) * model_dim % 16 != 0:
raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(model_dim))
inf_module = RaggedOpsBuilder().load()
self.kernel = inf_module.gather_for_logits
def __call__(self, final_token_activations: torch.Tensor, all_activations: torch.Tensor,
ragged_wrapper: RaggedBatchWrapper) -> torch.Tensor:
"""
Gather the hidden states of the final token of each sequence from `all_activations` into
`final_token_activations`.
Args:
final_token_activations (torch.Tensor): Output tensor of shape [num_seqs, model_dim]
all_activations (torch.Tensor): Input tensor of shape [num_tokens, model_dim]
ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch.
"""
self.kernel(final_token_activations, all_activations, ragged_wrapper.batch_metadata_buffer(),
ragged_wrapper.inflight_seq_descriptors())
return final_token_activations
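Typical usage looks roughly like the sketch below; model_dim, n_seqs, and the hidden_states / ragged_wrapper objects are placeholders supplied by the surrounding engine.

gather = RaggedLogitsGather(model_dim=4096, fp_dtype=torch.float16)
final_acts = torch.empty(n_seqs, 4096, dtype=torch.float16, device="cuda")  # n_seqs: in-flight sequences
final_acts = gather(final_acts, hidden_states, ragged_wrapper)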
