Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Reza Yazdani committed (via GitHub) 2022-01-07 18:40:31 -08:00
Parent 7e857aab9a
Commit 289c3f9ba4
No known key found for this signature
GPG Key ID: 4AEE18F83AFDEB23
10 changed files: 587 additions and 107 deletions


@@ -0,0 +1,129 @@
#include "custom_cuda_layers.h"
#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
unsigned head_id = blockIdx.x * blockDim.y + threadIdx.y;
if (head_id < total_count) {
unsigned offset = head_id * head_size + threadIdx.x;
unsigned tid = threadIdx.x;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
while (tid < rotary_dim) {
// tid is the rotary channel index here (the float kernel above calls it lane)
float inv_freq = (float)((tid / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset];
float k = (float)key_layer[offset];
float rotary_sign = (tid % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset] = (__half)q;
key_layer[offset] = (__half)k;
tid += blockDim.x;
offset += blockDim.x;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
cudaStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
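For reference (not part of the commit), a minimal PyTorch sketch of the same interleaved rotary update the kernel performs in place. The [seq_len, num_heads, head_size] layout and the name rotary_ref are illustrative assumptions; only the first rotary_dim (even) channels of each head are rotated, and each even/odd channel pair shares one frequency, as in the kernel.

import torch

def rotary_ref(q, k, rotary_dim, seq_offset=0, base=10000.0):
    # q, k: [seq_len, num_heads, head_size]
    seq_len = q.shape[0]
    pos = torch.arange(seq_offset, seq_offset + seq_len, dtype=torch.float32)
    chan = torch.arange(rotary_dim)
    # One frequency per channel pair: exponent = ((lane / 2) * 2) / rotary_dim
    theta = pos[:, None] / base ** ((chan - chan % 2).float() / rotary_dim)
    cos, sin = theta.cos()[:, None, :], theta.sin()[:, None, :]  # broadcast over heads

    def rotate(x):
        r = x[..., :rotary_dim]
        partner = torch.empty_like(r)
        partner[..., 0::2] = -r[..., 1::2]  # even channel pairs with -odd neighbor
        partner[..., 1::2] = r[..., 0::2]   # odd channel pairs with +even neighbor
        out = x.clone()
        out[..., :rotary_dim] = r * cos + partner * sin
        return out

    return rotate(q), rotate(k)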

csrc/transformer/inference/csrc/gelu.cu Executable file → Normal file

@@ -264,3 +264,107 @@ template void launch_bias_residual<__half>(__half*,
int,
int,
cudaStream_t);
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += (out.x + res_vec.x + bias_data.x);
data.y += (out.y + res_vec.y + bias_data.y);
data.z += (out.z + res_vec.z + bias_data.z);
data.w += (out.w + res_vec.w + bias_data.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
int total_count,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += (low_out.x + low_res.x + low_bias.x);
low_data.y += (low_out.y + low_res.y + low_bias.y);
high_data.x += (high_out.x + high_res.x + high_bias.x);
high_data.y += (high_out.y + high_res.y + high_bias.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
int hidden_dim,
int batch,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, total_count, hidden_dim / 4);
}
template void
launch_gptj_residual_add<float>(float*, float*, float*, float*, int, int, cudaStream_t);
template void
launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t);
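As a sanity reference only: GPT-J evaluates attention and the MLP in parallel off the same layer norm, so the kernel folds the residual input, both branch outputs, and the MLP output bias into a single elementwise add. A hedged PyTorch equivalent (names and shapes are assumptions):

import torch

def gptj_residual_add_ref(hidden_states, mlp_out, attn_out, mlp_bias):
    # hidden_states, mlp_out, attn_out: [batch, seq_len, hidden]; mlp_bias: [hidden]
    # The CUDA path writes this result back into the MLP output tensor in place.
    return hidden_states + mlp_out + attn_out + mlp_bias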


@@ -204,16 +204,19 @@ at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& b
}
template <typename T>
void qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
int bsz = input.size(0) * input.size(1);
@@ -236,16 +239,17 @@ void qkv_unfused_cublas(at::Tensor& output,
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
at::Tensor ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -256,9 +260,10 @@ at::Tensor ds_qkv_gemm(at::Tensor& input,
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
auto inp_norm =
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
return output;
return {output, inp_norm};
}
template <typename T>
@@ -592,6 +597,126 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& weight_out)
{
// cudaStreamWaitEvent(
// Context::Instance().GetCurrentStream(true), Context::Instance().GetCompEvent(1), 0);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
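Ignoring cuBLAS transpose conventions and the exact GeLU variant implemented by launch_bias_gelu, the fused path amounts to the usual two-GEMM MLP with a GeLU in between. A rough PyTorch sketch under those assumptions (shapes and the function name are illustrative):

import torch
import torch.nn.functional as F

def fused_gemm_gelu_ref(x, weight, bias, weight_out):
    # x: [batch, seq_len, hidden]; weight: [hidden, intermediate]
    # bias: [intermediate]; weight_out: [intermediate, hidden]
    return F.gelu(x @ weight + bias) @ weight_out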
void gptj_residual_add(at::Tensor& output,
at::Tensor& input,
at::Tensor& attention_output,
at::Tensor& output_b)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
// cudaStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if (input.scalar_type() == at::kFloat)
launch_gptj_residual_add<float>((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
Context::Instance().GetCurrentStream());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
Context::Instance().GetCurrentStream());
return {query_cont, key_cont};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
@@ -627,4 +752,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("gptj_residual_add", &gptj_residual_add, "DeepSpeed mlp with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
}

csrc/transformer/inference/includes/context.h Executable file → Normal file

@@ -40,7 +40,7 @@ inline int DS_GET_BLOCKS(const int N)
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
@@ -50,12 +50,16 @@ public:
throw std::runtime_error(message);
}
cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
}
virtual ~Context()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
cudaEventDestroy(_comp1_event);
cudaEventDestroy(_comp2_event);
}
static Context& Instance()
@@ -81,13 +85,19 @@ public:
curandGenerator_t& GetRandGenerator() { return _gen; }
cudaStream_t GetCurrentStream()
cudaStream_t GetCurrentStream(bool other_stream = false)
{
// get current pytorch stream.
if (other_stream) {
if (!_stream) _stream = at::cuda::getStreamFromPool(true);
return _stream;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
@@ -108,5 +118,11 @@ private:
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
cudaEvent_t _comp1_event;
cudaEvent_t _comp2_event;
cudaStream_t _stream;
std::vector<std::array<int, 3>> _gemm_algos;
};


@@ -77,3 +77,22 @@ void launch_dequantize(T* output,
unsigned groups,
unsigned merge_count,
cudaStream_t stream);
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
int hidden_dim,
int batch,
cudaStream_t stream);
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
cudaStream_t stream);


@@ -2,7 +2,7 @@ import copy
import torch
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy
from .replace_policy import replace_policies
from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
from ..runtime.weight_quantizer import WeightQuantization
@@ -200,11 +200,12 @@ def replace_transformer_layer(orig_layer_impl,
_4hh_w = _4hh_w.half()
if quantize or fp16:
dense_b = dense_b.half()
qkvb = qkvb if qkvb is None else qkvb.half()
dense_b = dense_b if dense_b is None else dense_b.half()
_h4h_b = _h4h_b.half()
_4hh_b = _4hh_b.half()
attn_nw = attn_nw.half()
attn_nb = attn_nb.half()
attn_nw = attn_nw if dense_b is None else attn_nw.half()
attn_nb = attn_nb if dense_b is None else attn_nb.half()
input_nw = input_nw.half()
input_nb = input_nb.half()
@@ -216,7 +217,9 @@ def replace_transformer_layer(orig_layer_impl,
heads=num_attention_heads,
layer_norm_eps=config.layer_norm_eps if hasattr(
config,
'layer_norm_eps') else 1e-12,
'layer_norm_eps') else
(config.layer_norm_epsilon if hasattr(config,
'layer_norm_epsilon') else 1e-12),
fp16=fp16,
pre_layer_norm=preln,
mp_size=mp_size,
@@ -227,7 +230,10 @@ def replace_transformer_layer(orig_layer_impl,
if hasattr(config,
'attention_layers') else False),
window_size=(config.window_size if hasattr(config,
'window_size') else 1))
'window_size') else 1),
rotary_dim=(config.rotary_dim if hasattr(config,
'rotary_dim') else -1),
mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))
if quantize and quantize_settings is not None:
(quantization_scales,
@@ -278,18 +284,10 @@ def replace_transformer_layer(orig_layer_impl,
attn_block = new_module.attention
attn_block.attn_qkvw.data = mp_replace.qkv_copy(attn_block.attn_qkvw.data,
qkvw)
if qkvb is not None:
if fp16:
qkvb = qkvb.half()
attn_block.attn_qkvb.data = mp_replace.qkv_copy(
attn_block.attn_qkvb.data,
qkvb)
else:
attn_block.attn_qkvb = qkvb
attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb.data, qkvb)
attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w)
attn_block.attn_ob.data = mp_replace.copy(attn_block.attn_ob.data, dense_b)
attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob.data, dense_b)
mpl_block = new_module.mlp
mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data, _h4h_w)
@@ -297,8 +295,10 @@ def replace_transformer_layer(orig_layer_impl,
mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data, _4hh_w)
mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data, _4hh_b)
new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())
new_module.mlp.attn_nw = attn_nw if dense_b is None else attn_nw.to(
torch.cuda.current_device())
new_module.mlp.attn_nb = attn_nb if dense_b is None else attn_nb.to(
torch.cuda.current_device())
new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
else:


@@ -143,6 +143,50 @@ class HFGPTNEOLayerPolicy(DSPolicy):
self.client_module.ln_1.bias.data
class HFGPTJLayerPolicy(DSPolicy):
_orig_layer_class = None
def __init__(self, client_module, inference=True):
super().__init__(inference, scale_attention=True)
self.client_module = client_module
try:
import transformers
HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock
except:
HFGPTJLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):
return self.client_module.attn.q_proj.weight.data.shape[1], \
self.client_module.attn.num_attention_heads
def attention(self):
qw = self.client_module.attn.q_proj.weight.data
kw = self.client_module.attn.k_proj.weight.data
vw = self.client_module.attn.v_proj.weight.data
qkvw = torch.cat((qw, kw, vw), dim=0)
return self.linear_layer, \
qkvw, \
None, \
self.client_module.attn.out_proj.weight.data, \
None, \
self.scale_attention
def mlp(self):
return self.linear_layer, \
self.client_module.mlp.fc_in.weight.data, \
self.client_module.mlp.fc_in.bias.data, \
self.client_module.mlp.fc_out.weight.data, \
self.client_module.mlp.fc_out.bias.data
def layerNorm(self):
return None, \
None, \
self.client_module.ln_1.weight.data, \
self.client_module.ln_1.bias.data
class MegatronLayerPolicy(DSPolicy):
_orig_layer_class = None
@@ -234,6 +278,7 @@ class HFGPT2LayerPolicy(DSPolicy):
replace_policies = [
HFBertLayerPolicy,
HFGPTNEOLayerPolicy,
HFGPTJLayerPolicy,
MegatronLayerPolicy,
HFGPT2LayerPolicy,
]


@@ -226,19 +226,21 @@ def top1gating(logits: Tensor,
exp_selection_uniform_map[logits.device] = uniform
mask1_rand = mask1 * uniform(mask1.shape)
else:
mask1_rand = mask1
assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
top_idx = _top_idx(mask1_rand, capacity)
top_idx = _top_idx(mask1_rand, capacity)
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1
if use_tutel:
# Tutel doesn't support index values masked with zero
# so we need to replace masked indices with -1
indices_mask = mask1.sum(dim=1) * num_experts - 1
indices1_s = torch.min(indices1_s, indices_mask)
if use_tutel:
# Tutel doesn't support index values masked with zero
# so we need to replace masked indices with -1
indices_mask = mask1.sum(dim=1) * num_experts - 1
indices1_s = torch.min(indices1_s, indices_mask)
# Compute locations in capacity buffer
if use_tutel:


@@ -66,7 +66,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
triangular_masking=True,
local_attention=False,
window_size=256,
return_tuple=True):
rotary_dim=-1,
return_tuple=True,
mlp_after_attn=True):
super(DeepSpeedInferenceConfig,
self).__init__(
hidden_size,
@@ -85,7 +87,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.triangular_masking = triangular_masking
self.local_attention = local_attention
self.window_size = window_size
self.rotary_dim = rotary_dim
self.return_tuple = return_tuple
self.mlp_after_attn = mlp_after_attn
@classmethod
def from_dict(cls, json_object):
@@ -180,6 +184,14 @@ class DeepSpeedSelfAttentionFunction(Function):
unfused_mode = not config.specialized_mode or \
mixed_query.shape[1] >= 32 or head_size > 128
if config.rotary_dim > 0:
mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
mixed_query,
key_layer,
config.rotary_dim,
0 if layer_past is None else layer_past[0].shape[-2],
num_attention_heads_per_partition)
if layer_past is not None:
past_key, past_value = layer_past
if unfused_mode:
@@ -189,21 +201,23 @@ class DeepSpeedSelfAttentionFunction(Function):
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer),
dim=-2)
presents = (key_layer, value_layer)
if unfused_mode:
mixed_query = _transpose_for_scores(mixed_query, False, True)
key_layer1 = _transpose_for_scores(
key_layer = _transpose_for_scores(
key_layer,
True,
True) / (norm_factor if config.scale_attention else 1.0)
value_layer1 = _transpose_for_scores(value_layer, False, True)
value_layer = _transpose_for_scores(value_layer, False, True)
if layer_past is None:
attn_key_value = score_context_func(
mixed_query,
(key_layer1 if unfused_mode else key_layer),
key_layer,
torch.empty(1),
(input_mask),
(value_layer1 if unfused_mode else value_layer),
input_mask,
value_layer,
torch.empty(1),
num_attention_heads_per_partition,
(1 / norm_factor if config.scale_attention else 1.0),
@@ -215,11 +229,11 @@ class DeepSpeedSelfAttentionFunction(Function):
else:
attn_key_value = score_context_func(
mixed_query,
(key_layer1 if unfused_mode else past_key.type_as(key_layer)),
(key_layer1 if unfused_mode else key_layer),
(input_mask),
(value_layer1 if unfused_mode else past_value.type_as(value_layer)),
(value_layer1 if unfused_mode else value_layer),
(key_layer if unfused_mode else past_key.type_as(key_layer)),
key_layer,
input_mask,
(value_layer if unfused_mode else past_value.type_as(value_layer)),
value_layer,
num_attention_heads_per_partition,
(1 / norm_factor if config.scale_attention else 1.0),
(not unfused_mode),
@@ -235,7 +249,7 @@ class DeepSpeedSelfAttentionFunction(Function):
# Transpose Context
context_layer = _transpose_for_context(context_layer)
return context_layer, key_layer, value_layer
return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer
def selfAttention_fp():
vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
@@ -255,10 +269,10 @@ class DeepSpeedSelfAttentionFunction(Function):
norm_b,
config.epsilon,
(attn_qkvb is not None))
context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
context_layer, key_layer, value_layer = compute_attention(qkv_out[0], input_mask)
output = vector_matmul_func(context_layer, attn_ow)
return output, key_layer, value_layer, context_layer
return output, key_layer, value_layer, context_layer, qkv_out[-1] # attn_out, present_key, present_value, context_output, inp_norm
def selfAttention_int8():
if not config.pre_layer_norm:
@@ -290,12 +304,12 @@ class DeepSpeedSelfAttentionFunction(Function):
if config.q_int8:
output, key_layer, value_layer, context_layer = selfAttention_int8()
else:
output, key_layer, value_layer, context_layer = selfAttention_fp()
output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
torch.distributed.all_reduce(output, group=mp_group)
return (output, key_layer, value_layer, context_layer)
return (output, key_layer, value_layer, context_layer, inp_norm)
@staticmethod
def backward(ctx, grad_output, grad_output1, grad_output2, grad_output3):
@@ -396,9 +410,12 @@ class DeepSpeedMLPFunction(Function):
output_w,
q_scales,
q_groups,
merge_count):
merge_count,
mlp_gemm_func,
fused_gemm_gelu,
vector_matmul_func,
bias_residual_func):
if config.q_int8:
(intermediate,
residual_add) = inference_cuda_module.mlp_gemm_int8(
input,
@@ -418,29 +435,26 @@ class DeepSpeedMLPFunction(Function):
q_groups,
(merge_count))
else:
mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \
inference_cuda_module.mlp_gemm_fp32
vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
inference_cuda_module.vector_matmul_fp32
(intermediate,
residual_add) = mlp_gemm_func(input,
residual,
bias,
inter_w,
inter_b,
attn_nw,
attn_nb,
config.epsilon,
config.pre_layer_norm)
output = vector_matmul_func(intermediate, output_w)
if attn_nw is None:
output = fused_gemm_gelu(input, inter_w, inter_b, output_w)
else:
(intermediate,
residual_add) = mlp_gemm_func(input,
residual,
bias,
inter_w,
inter_b,
attn_nw,
attn_nb,
config.epsilon,
config.pre_layer_norm)
output = vector_matmul_func(intermediate, output_w)
if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
torch.distributed.all_reduce(output, group=mp_group)
bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
inference_cuda_module.bias_residual_fp32
output = bias_residual_func(output, residual_add, output_b)
if attn_nw is not None:
output = bias_residual_func(output, residual_add, output_b)
return output
@@ -479,6 +493,15 @@ class DeepSpeedMLP(nn.Module):
self.merge_count = int(math.log2(merge_count))
self.mp_group = mp_group
self.mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \
inference_cuda_module.mlp_gemm_fp32
self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
inference_cuda_module.vector_matmul_fp32
self.fused_gemm_gelu = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
inference_cuda_module.fused_gemm_gelu_fp32
self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
inference_cuda_module.bias_residual_fp32
def forward(self, input, residual, bias):
return DeepSpeedMLPFunction.apply(input,
@@ -494,7 +517,11 @@ class DeepSpeedMLP(nn.Module):
self.output_w,
self.q_scales,
self.q_groups,
self.merge_count)
self.merge_count,
self.mlp_gemm_func,
self.fused_gemm_gelu,
self.vector_matmul_func,
self.bias_residual_func)
class DeepSpeedTransformerInference(nn.Module):
@@ -528,21 +555,6 @@ class DeepSpeedTransformerInference(nn.Module):
self.config = config
self.config.layer_id = DeepSpeedTransformerInference.layer_id
DeepSpeedTransformerInference.layer_id += 1
self.attention = DeepSpeedSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
qkv_merging)
self.mlp = DeepSpeedMLP(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
global inference_cuda_module
global specialized_mode
@@ -560,6 +572,22 @@ class DeepSpeedTransformerInference(nn.Module):
self.config.specialized_mode = specialized_mode
print("DeepSpeed Transformer Inference config is ", self.config.__dict__)
self.attention = DeepSpeedSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
qkv_merging)
self.mlp = DeepSpeedMLP(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
def forward(self,
input,
input_mask=None,
@@ -596,13 +624,15 @@ class DeepSpeedTransformerInference(nn.Module):
self.norm_b)
if get_present:
attention_output, p_key, p_value, _ = attention_output
presents = (p_key, p_value)
presents = (attention_output[1], attention_output[2])
elif output_attentions:
attention_output, _, _, context_output = attention_output
else:
attention_output, _, _, _ = attention_output
output = self.mlp(attention_output, input, self.attention.attn_ob)
context_output = attention_output[3]
output = self.mlp(
attention_output[0]
if self.config.mlp_after_attn else attention_output[-1],
input,
self.attention.attn_ob)
if not self.config.pre_layer_norm:
ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \
@@ -612,8 +642,13 @@ class DeepSpeedTransformerInference(nn.Module):
self.norm_b,
self.config.epsilon)
if input_type != output.dtype:
output = output.to(input_type)
if not self.config.mlp_after_attn:
inference_cuda_module.gptj_residual_add(output,
input,
attention_output[0],
self.mlp.output_b)
output = output.to(input_type)
if get_present:
output = (output, presents)


@@ -19,6 +19,7 @@ class InferenceBuilder(CUDAOpBuilder):
'csrc/transformer/inference/csrc/normalize.cu',
'csrc/transformer/inference/csrc/softmax.cu',
'csrc/transformer/inference/csrc/dequantize.cu',
'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
]
def include_paths(self):
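Once the extension is built, the new bindings are callable from Python. A hedged end-to-end sketch for the rotary op: the op-builder import path and the concrete sizes are assumptions, while the call signature and the batch/seq/hidden tensor layout follow the apply_rotary_pos_emb binding above.

import torch
from deepspeed.ops.op_builder import InferenceBuilder  # assumed import path

inference_module = InferenceBuilder().load()

batch, seq_len, heads, head_size, rotary_dim = 1, 8, 16, 16, 8
q = torch.randn(batch, seq_len, heads * head_size, dtype=torch.half, device='cuda')
k = torch.randn_like(q)

# offset = number of tokens already in the KV cache (layer_past length), 0 for the prompt pass
q_rot, k_rot = inference_module.apply_rotary_pos_emb(q, k, rotary_dim, 0, heads)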