diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
new file mode 100644
index 000000000..4c5ef02b1
--- /dev/null
+++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -0,0 +1,129 @@
+#include "custom_cuda_layers.h"
+
+#include <cooperative_groups.h>
+
+namespace cg = cooperative_groups;
+
+__global__ void apply_rotary_pos_emb(float* mixed_query,
+                                     float* key_layer,
+                                     unsigned rotary_dim,
+                                     unsigned seq_len,
+                                     unsigned seq_offset,
+                                     unsigned num_heads,
+                                     unsigned head_size,
+                                     unsigned total_count)
+{
+    cg::thread_block b = cg::this_thread_block();
+    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
+
+    int id = threadIdx.x;
+    int gid = id >> 5;
+    int lane = id & 0x1f;
+
+    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
+    unsigned offset = head_id * head_size;
+
+    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+
+    if (head_id < total_count) {
+        while (lane < rotary_dim) {
+            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
+            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
+            float q = mixed_query[offset + lane];
+            float k = key_layer[offset + lane];
+            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
+            float q_rot = (q * rotary_sign);
+            float k_rot = (k * rotary_sign);
+            q_rot = g.shfl_xor(q_rot, 1);
+            k_rot = g.shfl_xor(k_rot, 1);
+            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
+            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
+
+            mixed_query[offset + lane] = q;
+            key_layer[offset + lane] = k;
+
+            lane += WARP_SIZE;
+        }
+    }
+}
+
+__global__ void apply_rotary_pos_emb(__half* mixed_query,
+                                     __half* key_layer,
+                                     unsigned rotary_dim,
+                                     unsigned seq_len,
+                                     unsigned seq_offset,
+                                     unsigned num_heads,
+                                     unsigned head_size,
+                                     unsigned total_count)
+{
+#if __CUDA_ARCH__ >= 700
+
+    unsigned head_id = blockIdx.x * blockDim.y + threadIdx.y;
+    if (head_id < total_count) {
+        unsigned offset = head_id * head_size + threadIdx.x;
+        unsigned tid = threadIdx.x;
+        unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+
+        cg::thread_block b = cg::this_thread_block();
+        cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
+
+        while (tid < rotary_dim) {
+            float inv_freq = (float)((tid / 2) * 2) / (float)rotary_dim;
+            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
+            float q = (float)mixed_query[offset];
+            float k = (float)key_layer[offset];
+            float rotary_sign = (tid % 2 == 1 ? -1.0 : 1.0);
+            float q_rot = (q * rotary_sign);
+            float k_rot = (k * rotary_sign);
+            q_rot = g.shfl_xor(q_rot, 1);
+            k_rot = g.shfl_xor(k_rot, 1);
+            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
+            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
+
+            mixed_query[offset] = (__half)q;
+            key_layer[offset] = (__half)k;
+
+            tid += blockDim.x;
+            offset += blockDim.x;
+        }
+    }
+#endif
+}
+
+template <typename T>
+void launch_apply_rotary_pos_emb(T* mixed_query,
+                                 T* key_layer,
+                                 unsigned head_size,
+                                 unsigned seq_len,
+                                 unsigned rotary_dim,
+                                 unsigned offset,
+                                 unsigned num_heads,
+                                 unsigned batch,
+                                 cudaStream_t stream)
+{
+    int total_count = batch * num_heads * seq_len;
+    dim3 block_dims(1024);
+    dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);  // (batch_size);
+
+    apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
+        mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
+}
+
+template void launch_apply_rotary_pos_emb<float>(float*,
+                                                 float*,
+                                                 unsigned,
+                                                 unsigned,
+                                                 unsigned,
+                                                 unsigned,
+                                                 unsigned,
+                                                 unsigned,
+                                                 cudaStream_t);
+template void launch_apply_rotary_pos_emb<__half>(__half*,
+                                                  __half*,
+                                                  unsigned,
+                                                  unsigned,
+                                                  unsigned,
+                                                  unsigned,
+                                                  unsigned,
+                                                  unsigned,
+                                                  cudaStream_t);
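Both kernels implement the interleaved (GPT-J style) rotary embedding: within each head, channels 2i and 2i+1 form a pair that is rotated by the angle position / 10000^(2i / rotary_dim), and the `shfl_xor(..., 1)` exchange is how each lane obtains its pair partner. That is also why both lanes of a pair compute the same `inv_freq` (via `(lane / 2) * 2`) and only differ in the sign they feed into the shuffle. As a sanity reference, here is a rough PyTorch equivalent of what one launch computes; it is not part of the patch and assumes a `[batch, seq, heads, head_size]` layout, whereas the kernels operate in place on the flat query/key buffers:

```python
import torch

def rotary_ref(x, rotary_dim, seq_offset=0, base=10000.0):
    """Interleaved (GPT-J style) RoPE over the first `rotary_dim` channels of each head."""
    bsz, seq, heads, head_size = x.shape
    pos = torch.arange(seq_offset, seq_offset + seq, dtype=torch.float32, device=x.device)
    chan = torch.arange(rotary_dim, device=x.device)
    # Channels (2i, 2i+1) share the angle pos / base**(2i / rotary_dim).
    theta = pos[:, None] / base ** ((chan // 2 * 2).float() / rotary_dim)  # [seq, rotary_dim]
    cos, sin = theta.cos()[:, None, :], theta.sin()[:, None, :]            # broadcast over heads

    rot = x[..., :rotary_dim].float()
    even, odd = rot[..., 0::2], rot[..., 1::2]
    out = torch.empty_like(rot)
    out[..., 0::2] = even * cos[..., 0::2] - odd * sin[..., 0::2]  # lane 2i:   q*cos - partner*sin
    out[..., 1::2] = odd * cos[..., 1::2] + even * sin[..., 1::2]  # lane 2i+1: q*cos + partner*sin
    return torch.cat([out.to(x.dtype), x[..., rotary_dim:]], dim=-1)
```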
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
old mode 100755
new mode 100644
index 10adaa6fe..d972c019c
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -264,3 +264,107 @@ template void launch_bias_residual<__half>(__half*,
                                            int,
                                            int,
                                            cudaStream_t);
+
+__global__ void gptj_residual_add(float* input,
+                                  float* output,
+                                  float* attn,
+                                  float* bias,
+                                  int total_count,
+                                  int intermediate_size)
+{
+    float4* input_cast = reinterpret_cast<float4*>(input);
+    float4* output_cast = reinterpret_cast<float4*>(output);
+    float4* attn_cast = reinterpret_cast<float4*>(attn);
+    float4* bias_cast = reinterpret_cast<float4*>(bias);
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (offset < total_count) {
+        float4 data = input_cast[offset];
+        float4 out = output_cast[offset];
+        float4 res_vec = attn_cast[offset];
+        float4 bias_data = bias_cast[offset % intermediate_size];
+
+        data.x += (out.x + res_vec.x + bias_data.x);
+        data.y += (out.y + res_vec.y + bias_data.y);
+        data.z += (out.z + res_vec.z + bias_data.z);
+        data.w += (out.w + res_vec.w + bias_data.w);
+
+        output_cast[offset] = data;
+    }
+}
+
+__global__ void gptj_residual_add(__half* input,
+                                  __half* output,
+                                  __half* attn,
+                                  __half* bias,
+                                  int total_count,
+                                  int intermediate_size)
+{
+#if __CUDA_ARCH__ >= 700
+
+    float2* input_cast = reinterpret_cast<float2*>(input);
+    float2* output_cast = reinterpret_cast<float2*>(output);
+    float2* attn_cast = reinterpret_cast<float2*>(attn);
+
+    float2* bias_cast = reinterpret_cast<float2*>(bias);
+
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (offset < total_count) {
+        float2 vals_vec = input_cast[offset];
+        float2 out_vec = output_cast[offset];
+        float2 res_vec = attn_cast[offset];
+
+        float2 bias_vec = bias_cast[offset % intermediate_size];
+
+        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
+        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
+        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
+        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
+
+        float2 low_data = __half22float2(vals_half[0]);
+        float2 high_data = __half22float2(vals_half[1]);
+
+        float2 low_out = __half22float2(out_half[0]);
+        float2 high_out = __half22float2(out_half[1]);
+
+        float2 low_res = __half22float2(res_half[0]);
+        float2 high_res = __half22float2(res_half[1]);
+
+        float2 low_bias = __half22float2(bias_half[0]);
+        float2 high_bias = __half22float2(bias_half[1]);
+
+        low_data.x += (low_out.x + low_res.x + low_bias.x);
+        low_data.y += (low_out.y + low_res.y + low_bias.y);
+        high_data.x += (high_out.x + high_res.x + high_bias.x);
+        high_data.y += (high_out.y + high_res.y + high_bias.y);
+
+        vals_half[0] = __float22half2_rn(low_data);
+        vals_half[1] = __float22half2_rn(high_data);
+
+        output_cast[offset] = vals_vec;
+    }
+#endif
+}
+
+template <typename T>
+void launch_gptj_residual_add(T* input,
+                              T* output,
+                              T* attn,
+                              T* bias,
+                              int hidden_dim,
+                              int batch,
+                              cudaStream_t stream)
+{
+    int total_count = batch * hidden_dim / 4;
+    dim3 block_dims(1024);
+    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);
+
+    gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
+        input, output, attn, bias, total_count, hidden_dim / 4);
+}
+
+template void
+launch_gptj_residual_add<float>(float*, float*, float*, float*, int, int, cudaStream_t);
+template void
+launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t);
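GPT-J evaluates attention and the MLP in parallel off a single input layernorm and sums both branches back into the residual stream, so the kernels above fold that merge — layer input + attention output + MLP output + the MLP output-projection bias — into one vectorized pass (float4 for fp32, packed __half2 for fp16), writing the result over the MLP output buffer. An elementwise PyTorch sketch of what a single launch computes (names are illustrative, not from the patch):

```python
import torch

def gptj_residual_add_ref(mlp_out, layer_input, attn_out, mlp_output_bias):
    # Elementwise equivalent of one launch_gptj_residual_add call; on the CUDA
    # path the result is written back into the MLP output tensor, and the MLP
    # output-projection bias is applied here instead of in the GEMM epilogue.
    return layer_input + attn_out + mlp_out + mlp_output_bias
```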
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index b587b6233..f0ab158db 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -204,16 +204,19 @@ at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& b
 }
 
 template <typename T>
-void qkv_unfused_cublas(at::Tensor& output,
-                        at::Tensor& input,
-                        at::Tensor& weight,
-                        at::Tensor& bias,
-                        at::Tensor& gamma,
-                        at::Tensor& beta,
-                        const float epsilon,
-                        bool add_bias)
+at::Tensor qkv_unfused_cublas(at::Tensor& output,
+                              at::Tensor& input,
+                              at::Tensor& weight,
+                              at::Tensor& bias,
+                              at::Tensor& gamma,
+                              at::Tensor& beta,
+                              const float epsilon,
+                              bool add_bias)
 {
     auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
+
+    // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
+
     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
     int bsz = input.size(0) * input.size(1);
@@ -236,16 +239,17 @@ void qkv_unfused_cublas(at::Tensor& output,
                         weight.size(1),
                         bsz,
                         Context::Instance().GetCurrentStream());
+    return inp_norm;
 }
 
 template <typename T>
-at::Tensor ds_qkv_gemm(at::Tensor& input,
-                       at::Tensor& weight,
-                       at::Tensor& bias,
-                       at::Tensor& gamma,
-                       at::Tensor& beta,
-                       const float epsilon,
-                       bool add_bias)
+std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
+                                    at::Tensor& weight,
+                                    at::Tensor& bias,
+                                    at::Tensor& gamma,
+                                    at::Tensor& beta,
+                                    const float epsilon,
+                                    bool add_bias)
 {
     auto input_cont = input.contiguous();
     auto options = at::TensorOptions()
@@ -256,9 +260,10 @@ at::Tensor ds_qkv_gemm(at::Tensor& input,
     auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
     int bsz = input_cont.size(0) * input_cont.size(1);
 
-    qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
+    auto inp_norm =
+        qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
 
-    return output;
+    return {output, inp_norm};
 }
 
 template <typename T>
@@ -592,6 +597,126 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
     return {output, residual_add};
 }
 
+template <typename T>
+at::Tensor fused_gemm_gelu(at::Tensor& input,
+                           at::Tensor& weight,
+                           at::Tensor& bias,
+                           at::Tensor& weight_out)
+{
+    // cudaStreamWaitEvent(
+    //     Context::Instance().GetCurrentStream(true), Context::Instance().GetCompEvent(1), 0);
+    auto input_cont = input.contiguous();
+    auto options = at::TensorOptions()
+                       .dtype(input_cont.options().dtype())
+                       .layout(at::kStrided)
+                       .device(at::kCUDA)
+                       .requires_grad(false);
+
+    auto intermediate =
+        at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
+    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
+    int bsz = input_cont.size(0) * input_cont.size(1);
+    float alpha = (T)1.0;
+    float gemm_beta = (T)0.0;
+    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
+    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                   CUBLAS_OP_N,
+                   CUBLAS_OP_N,
+                   weight.size(1),
+                   bsz,
+                   input.size(2),
+                   &alpha,
+                   &gemm_beta,
+                   (T*)weight.data_ptr(),
+                   (T*)input_cont.data_ptr(),
+                   (T*)intermediate.data_ptr(),
+                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+    launch_bias_gelu((T*)intermediate.data_ptr(),
+                     (T*)bias.data_ptr(),
+                     weight.size(1),
+                     bsz,
+                     Context::Instance().GetCurrentStream());
+
+    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                   CUBLAS_OP_N,
+                   CUBLAS_OP_N,
+                   weight_out.size(1),
+                   bsz,
+                   intermediate.size(2),
+                   &alpha,
+                   &gemm_beta,
+                   (T*)weight_out.data_ptr(),
+                   (T*)intermediate.data_ptr(),
+                   (T*)output.data_ptr(),
+                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+    // cudaEventRecord(Context::Instance().GetCompEvent(2),
+    //                 Context::Instance().GetCurrentStream(true));
+    return output;
+}
+
+void gptj_residual_add(at::Tensor& output,
+                       at::Tensor& input,
+                       at::Tensor& attention_output,
+                       at::Tensor& output_b)
+{
+    int bsz = input.size(0) * input.size(1);
+    int hidden_size = input.size(2);
+    // cudaStreamWaitEvent(
+    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
+    if (input.scalar_type() == at::kFloat)
+        launch_gptj_residual_add<float>((float*)input.data_ptr(),
+                                        (float*)output.data_ptr(),
+                                        (float*)attention_output.data_ptr(),
+                                        (float*)output_b.data_ptr(),
+                                        hidden_size,
+                                        bsz,
+                                        Context::Instance().GetCurrentStream());
+    else
+        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
+                                         (__half*)output.data_ptr(),
+                                         (__half*)attention_output.data_ptr(),
+                                         (__half*)output_b.data_ptr(),
+                                         hidden_size,
+                                         bsz,
+                                         Context::Instance().GetCurrentStream());
+}
+
+std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
+                                             at::Tensor& key_layer,
+                                             unsigned rotary_dim,
+                                             unsigned offset,
+                                             unsigned num_heads)
+{
+    auto query_cont = mixed_query.contiguous();
+    auto key_cont = key_layer.contiguous();
+
+    unsigned bsz = mixed_query.size(0);
+    unsigned head_size = mixed_query.size(2) / num_heads;
+    unsigned seq_len = mixed_query.size(1);
+
+    if (mixed_query.scalar_type() == at::kFloat)
+        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
+                                           (float*)key_cont.data_ptr(),
+                                           head_size,
+                                           seq_len,
+                                           rotary_dim,
+                                           offset,
+                                           num_heads,
+                                           bsz,
+                                           Context::Instance().GetCurrentStream());
+    else
+        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
+                                            (__half*)key_cont.data_ptr(),
+                                            head_size,
+                                            seq_len,
+                                            rotary_dim,
+                                            offset,
+                                            num_heads,
+                                            bsz,
+                                            Context::Instance().GetCurrentStream());
+    return {query_cont, key_cont};
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
@@ -627,4 +752,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("linear_layer_int8",
           &ds_linear_layer_int8<__half>,
           "DeepSpeed linear_layer with int8 (CUDA)");
+    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
+    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
+    m.def("gptj_residual_add", &gptj_residual_add, "DeepSpeed GPT-J residual add (CUDA)");
+    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed rotary embedding (CUDA)");
 }
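The binding changes serve the same parallel-branch layout: `ds_qkv_gemm` now also returns the layernorm output (`inp_norm`) so the GPT-J MLP can consume the same normalized activations as attention, `fused_gemm_gelu` runs the two MLP GEMMs with the bias-GeLU applied in between, and the commented-out event/stream calls sketch a (currently disabled) overlap of the MLP with attention on a side stream. Below is a hedged Python sketch of how the new ops could chain together for one GPT-J block; the `qkv_gemm_fp16` name, the naive QKV split, and the plain-torch attention are illustrative stand-ins for existing kernels, not interfaces added by this patch:

```python
import torch

def gptj_block_sketch(op, hidden, qkvw, qkvb, attn_ow, norm_w, norm_b,
                      inter_w, inter_b, output_w, output_b,
                      heads, rotary_dim, past_len=0, epsilon=1e-5):
    """Hedged sketch: one GPT-J block expressed with the new bindings.

    `op` is the compiled inference extension (e.g. InferenceBuilder().load());
    everything not registered in this patch is a placeholder.
    """
    # One layernorm feeds both branches; qkv_gemm now returns (qkv_out, inp_norm).
    qkv_out, inp_norm = op.qkv_gemm_fp16(hidden, qkvw, qkvb, norm_w, norm_b,
                                         epsilon, qkvb is not None)
    query, key, value = qkv_out.chunk(3, dim=-1)
    # In-place rotary embedding on the contiguous query/key buffers.
    query, key = op.apply_rotary_pos_emb(query, key, rotary_dim, past_len, heads)

    def heads_first(x):  # [b, s, h*d] -> [b, h, s, d]
        b, s, _ = x.shape
        return x.view(b, s, heads, -1).permute(0, 2, 1, 3)

    q, k, v = heads_first(query), heads_first(key), heads_first(value)
    probs = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    context = (probs @ v).permute(0, 2, 1, 3).reshape(hidden.shape)
    attn_out = context @ attn_ow                       # attention output projection

    # MLP branch reuses inp_norm; two GEMMs with the bias-GeLU fused in between.
    mlp_out = op.fused_gemm_gelu_fp16(inp_norm, inter_w, inter_b, output_w)
    # Final merge: hidden (residual) + attn_out + mlp_out + output bias, written into mlp_out.
    op.gptj_residual_add(mlp_out, hidden, attn_out, output_b)
    return mlp_out
```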
diff --git a/csrc/transformer/inference/includes/context.h b/csrc/transformer/inference/includes/context.h
old mode 100755
new mode 100644
index 4385bd7d5..de0848eb7
--- a/csrc/transformer/inference/includes/context.h
+++ b/csrc/transformer/inference/includes/context.h
@@ -40,7 +40,7 @@ inline int DS_GET_BLOCKS(const int N)
 class Context {
 public:
-    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
+    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
     {
         curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
         curandSetPseudoRandomGeneratorSeed(_gen, 123);
@@ -50,12 +50,16 @@ public:
             throw std::runtime_error(message);
         }
         cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
+        cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
+        cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
     }
 
     virtual ~Context()
     {
         cublasDestroy(_cublasHandle);
         cudaFree(_workspace);
+        cudaEventDestroy(_comp1_event);
+        cudaEventDestroy(_comp2_event);
     }
 
     static Context& Instance()
@@ -81,13 +85,19 @@ public:
 
     curandGenerator_t& GetRandGenerator() { return _gen; }
 
-    cudaStream_t GetCurrentStream()
+    cudaStream_t GetCurrentStream(bool other_stream = false)
     {
         // get current pytorch stream.
+        if (other_stream) {
+            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
+            return _stream;
+        }
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         return stream;
     }
 
+    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
+
     cublasHandle_t GetCublasHandle() { return _cublasHandle; }
 
     std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
@@ -108,5 +118,11 @@ private:
     uint64_t _seed;
     uint64_t _curr_offset;
     size_t _workSpaceSize;
+
+    cudaEvent_t _comp1_event;
+    cudaEvent_t _comp2_event;
+
+    cudaStream_t _stream;
+
     std::vector<std::array<int, 3>> _gemm_algos;
 };
diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h
index dfb43a07a..1d4c7ced1 100644
--- a/csrc/transformer/inference/includes/custom_cuda_layers.h
+++ b/csrc/transformer/inference/includes/custom_cuda_layers.h
@@ -77,3 +77,22 @@ void launch_dequantize(T* output,
                        unsigned groups,
                        unsigned merge_count,
                        cudaStream_t stream);
+
+template <typename T>
+void launch_gptj_residual_add(T* input,
+                              T* output,
+                              T* attn,
+                              T* bias,
+                              int hidden_dim,
+                              int batch,
+                              cudaStream_t stream);
+template <typename T>
+void launch_apply_rotary_pos_emb(T* mixed_query,
+                                 T* key_layer,
+                                 unsigned head_size,
+                                 unsigned seq_len,
+                                 unsigned rotary_dim,
+                                 unsigned offset,
+                                 unsigned num_heads,
+                                 unsigned batch,
+                                 cudaStream_t stream);
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index b6e7f7c79..148ca2bd9 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -2,7 +2,7 @@ import copy
 import torch
 import deepspeed
 import deepspeed.ops.transformer as transformer_inference
-from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
+from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy
 from .replace_policy import replace_policies
 from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
 from ..runtime.weight_quantizer import WeightQuantization
@@ -200,11 +200,12 @@ def replace_transformer_layer(orig_layer_impl,
             _4hh_w = _4hh_w.half()
 
         if quantize or fp16:
-            dense_b = dense_b.half()
+            qkvb = qkvb if qkvb is
None else qkvb.half() + dense_b = dense_b if dense_b is None else dense_b.half() _h4h_b = _h4h_b.half() _4hh_b = _4hh_b.half() - attn_nw = attn_nw.half() - attn_nb = attn_nb.half() + attn_nw = attn_nw if dense_b is None else attn_nw.half() + attn_nb = attn_nb if dense_b is None else attn_nb.half() input_nw = input_nw.half() input_nb = input_nb.half() @@ -216,7 +217,9 @@ def replace_transformer_layer(orig_layer_impl, heads=num_attention_heads, layer_norm_eps=config.layer_norm_eps if hasattr( config, - 'layer_norm_eps') else 1e-12, + 'layer_norm_eps') else + (config.layer_norm_epsilon if hasattr(config, + 'layer_norm_epsilon') else 1e-12), fp16=fp16, pre_layer_norm=preln, mp_size=mp_size, @@ -227,7 +230,10 @@ def replace_transformer_layer(orig_layer_impl, if hasattr(config, 'attention_layers') else False), window_size=(config.window_size if hasattr(config, - 'window_size') else 1)) + 'window_size') else 1), + rotary_dim=(config.rotary_dim if hasattr(config, + 'rotary_dim') else -1), + mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy)) if quantize and quantize_settings is not None: (quantization_scales, @@ -278,18 +284,10 @@ def replace_transformer_layer(orig_layer_impl, attn_block = new_module.attention attn_block.attn_qkvw.data = mp_replace.qkv_copy(attn_block.attn_qkvw.data, qkvw) - - if qkvb is not None: - if fp16: - qkvb = qkvb.half() - attn_block.attn_qkvb.data = mp_replace.qkv_copy( - attn_block.attn_qkvb.data, - qkvb) - else: - attn_block.attn_qkvb = qkvb + attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb.data, qkvb) attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w) - attn_block.attn_ob.data = mp_replace.copy(attn_block.attn_ob.data, dense_b) + attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob.data, dense_b) mpl_block = new_module.mlp mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data, _h4h_w) @@ -297,8 +295,10 @@ def replace_transformer_layer(orig_layer_impl, mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data, _4hh_w) mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data, _4hh_b) - new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device()) - new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device()) + new_module.mlp.attn_nw = attn_nw if dense_b is None else attn_nw.to( + torch.cuda.current_device()) + new_module.mlp.attn_nb = attn_nb if dense_b is None else attn_nb.to( + torch.cuda.current_device()) new_module.norm_w.data = input_nw.to(torch.cuda.current_device()) new_module.norm_b.data = input_nb.to(torch.cuda.current_device()) else: diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index cda2a685d..b0bd238ea 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -143,6 +143,50 @@ class HFGPTNEOLayerPolicy(DSPolicy): self.client_module.ln_1.bias.data +class HFGPTJLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + super().__init__(inference, scale_attention=True) + self.client_module = client_module + try: + import transformers + HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock + except: + HFGPTJLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attn.q_proj.weight.data.shape[1], \ + self.client_module.attn.num_attention_heads + + def attention(self): + qw = self.client_module.attn.q_proj.weight.data + kw = 
self.client_module.attn.k_proj.weight.data + vw = self.client_module.attn.v_proj.weight.data + + qkvw = torch.cat((qw, kw, vw), dim=0) + + return self.linear_layer, \ + qkvw, \ + None, \ + self.client_module.attn.out_proj.weight.data, \ + None, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.fc_in.weight.data, \ + self.client_module.mlp.fc_in.bias.data, \ + self.client_module.mlp.fc_out.weight.data, \ + self.client_module.mlp.fc_out.bias.data + + def layerNorm(self): + return None, \ + None, \ + self.client_module.ln_1.weight.data, \ + self.client_module.ln_1.bias.data + + class MegatronLayerPolicy(DSPolicy): _orig_layer_class = None @@ -234,6 +278,7 @@ class HFGPT2LayerPolicy(DSPolicy): replace_policies = [ HFBertLayerPolicy, HFGPTNEOLayerPolicy, + HFGPTJLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, ] diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index f5b11f352..57fc3018c 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -226,19 +226,21 @@ def top1gating(logits: Tensor, exp_selection_uniform_map[logits.device] = uniform mask1_rand = mask1 * uniform(mask1.shape) + else: + mask1_rand = mask1 - assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." - top_idx = _top_idx(mask1_rand, capacity) + top_idx = _top_idx(mask1_rand, capacity) - new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) - mask1 = new_mask1 + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 - if use_tutel: - # Tutel doesn't support index values masked with zero - # so we need to replace masked indices with -1 - indices_mask = mask1.sum(dim=1) * num_experts - 1 - indices1_s = torch.min(indices1_s, indices_mask) + if use_tutel: + # Tutel doesn't support index values masked with zero + # so we need to replace masked indices with -1 + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) # Compute locations in capacity buffer if use_tutel: diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 36d2a7d43..b3715e43d 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -66,7 +66,9 @@ class DeepSpeedInferenceConfig(TransformerConfig): triangular_masking=True, local_attention=False, window_size=256, - return_tuple=True): + rotary_dim=-1, + return_tuple=True, + mlp_after_attn=True): super(DeepSpeedInferenceConfig, self).__init__( hidden_size, @@ -85,7 +87,9 @@ class DeepSpeedInferenceConfig(TransformerConfig): self.triangular_masking = triangular_masking self.local_attention = local_attention self.window_size = window_size + self.rotary_dim = rotary_dim self.return_tuple = return_tuple + self.mlp_after_attn = mlp_after_attn @classmethod def from_dict(cls, json_object): @@ -180,6 +184,14 @@ class DeepSpeedSelfAttentionFunction(Function): unfused_mode = not config.specialized_mode or \ mixed_query.shape[1] >= 32 or head_size > 128 + if config.rotary_dim > 0: + mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb( + mixed_query, + key_layer, + config.rotary_dim, + 0 
if layer_past is None else layer_past[0].shape[-2], + num_attention_heads_per_partition) + if layer_past is not None: past_key, past_value = layer_past if unfused_mode: @@ -189,21 +201,23 @@ class DeepSpeedSelfAttentionFunction(Function): value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2) + presents = (key_layer, value_layer) + if unfused_mode: mixed_query = _transpose_for_scores(mixed_query, False, True) - key_layer1 = _transpose_for_scores( + key_layer = _transpose_for_scores( key_layer, True, True) / (norm_factor if config.scale_attention else 1.0) - value_layer1 = _transpose_for_scores(value_layer, False, True) + value_layer = _transpose_for_scores(value_layer, False, True) if layer_past is None: attn_key_value = score_context_func( mixed_query, - (key_layer1 if unfused_mode else key_layer), + key_layer, torch.empty(1), - (input_mask), - (value_layer1 if unfused_mode else value_layer), + input_mask, + value_layer, torch.empty(1), num_attention_heads_per_partition, (1 / norm_factor if config.scale_attention else 1.0), @@ -215,11 +229,11 @@ class DeepSpeedSelfAttentionFunction(Function): else: attn_key_value = score_context_func( mixed_query, - (key_layer1 if unfused_mode else past_key.type_as(key_layer)), - (key_layer1 if unfused_mode else key_layer), - (input_mask), - (value_layer1 if unfused_mode else past_value.type_as(value_layer)), - (value_layer1 if unfused_mode else value_layer), + (key_layer if unfused_mode else past_key.type_as(key_layer)), + key_layer, + input_mask, + (value_layer if unfused_mode else past_value.type_as(value_layer)), + value_layer, num_attention_heads_per_partition, (1 / norm_factor if config.scale_attention else 1.0), (not unfused_mode), @@ -235,7 +249,7 @@ class DeepSpeedSelfAttentionFunction(Function): # Transpose Context context_layer = _transpose_for_context(context_layer) - return context_layer, key_layer, value_layer + return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ @@ -255,10 +269,10 @@ class DeepSpeedSelfAttentionFunction(Function): norm_b, config.epsilon, (attn_qkvb is not None)) - context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask) + context_layer, key_layer, value_layer = compute_attention(qkv_out[0], input_mask) output = vector_matmul_func(context_layer, attn_ow) - return output, key_layer, value_layer, context_layer + return output, key_layer, value_layer, context_layer, qkv_out[-1] # attn_out, present_key, present_value, context_output, inp_norm def selfAttention_int8(): if not config.pre_layer_norm: @@ -290,12 +304,12 @@ class DeepSpeedSelfAttentionFunction(Function): if config.q_int8: output, key_layer, value_layer, context_layer = selfAttention_int8() else: - output, key_layer, value_layer, context_layer = selfAttention_fp() + output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp() if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1: torch.distributed.all_reduce(output, group=mp_group) - return (output, key_layer, value_layer, context_layer) + return (output, key_layer, value_layer, context_layer, inp_norm) @staticmethod def backward(ctx, grad_output, grad_output1, grad_output2, grad_output3): @@ -396,9 +410,12 @@ class DeepSpeedMLPFunction(Function): output_w, q_scales, q_groups, - merge_count): + merge_count, + mlp_gemm_func, + fused_gemm_gelu, + vector_matmul_func, + bias_residual_func): if 
config.q_int8: - (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8( input, @@ -418,29 +435,26 @@ class DeepSpeedMLPFunction(Function): q_groups, (merge_count)) else: - mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \ - inference_cuda_module.mlp_gemm_fp32 - vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ - inference_cuda_module.vector_matmul_fp32 - (intermediate, - residual_add) = mlp_gemm_func(input, - residual, - bias, - inter_w, - inter_b, - attn_nw, - attn_nb, - config.epsilon, - config.pre_layer_norm) - output = vector_matmul_func(intermediate, output_w) + if attn_nw is None: + output = fused_gemm_gelu(input, inter_w, inter_b, output_w) + else: + (intermediate, + residual_add) = mlp_gemm_func(input, + residual, + bias, + inter_w, + inter_b, + attn_nw, + attn_nb, + config.epsilon, + config.pre_layer_norm) + output = vector_matmul_func(intermediate, output_w) if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1: torch.distributed.all_reduce(output, group=mp_group) - bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \ - inference_cuda_module.bias_residual_fp32 - - output = bias_residual_func(output, residual_add, output_b) + if attn_nw is not None: + output = bias_residual_func(output, residual_add, output_b) return output @@ -479,6 +493,15 @@ class DeepSpeedMLP(nn.Module): self.merge_count = int(math.log2(merge_count)) self.mp_group = mp_group + self.mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \ + inference_cuda_module.mlp_gemm_fp32 + self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ + inference_cuda_module.vector_matmul_fp32 + self.fused_gemm_gelu = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \ + inference_cuda_module.fused_gemm_gelu_fp32 + + self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \ + inference_cuda_module.bias_residual_fp32 def forward(self, input, residual, bias): return DeepSpeedMLPFunction.apply(input, @@ -494,7 +517,11 @@ class DeepSpeedMLP(nn.Module): self.output_w, self.q_scales, self.q_groups, - self.merge_count) + self.merge_count, + self.mlp_gemm_func, + self.fused_gemm_gelu, + self.vector_matmul_func, + self.bias_residual_func) class DeepSpeedTransformerInference(nn.Module): @@ -528,21 +555,6 @@ class DeepSpeedTransformerInference(nn.Module): self.config = config self.config.layer_id = DeepSpeedTransformerInference.layer_id DeepSpeedTransformerInference.layer_id += 1 - self.attention = DeepSpeedSelfAttention(self.config, - mp_group, - quantize_scales, - quantize_groups, - merge_count, - qkv_merging) - self.mlp = DeepSpeedMLP(self.config, - mp_group, - quantize_scales, - quantize_groups, - merge_count, - mlp_extra_grouping) - - self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size)) - self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) global inference_cuda_module global specialized_mode @@ -560,6 +572,22 @@ class DeepSpeedTransformerInference(nn.Module): self.config.specialized_mode = specialized_mode print("DeepSpeed Transformer Inference config is ", self.config.__dict__) + self.attention = DeepSpeedSelfAttention(self.config, + mp_group, + quantize_scales, + quantize_groups, + merge_count, + qkv_merging) + self.mlp = DeepSpeedMLP(self.config, + mp_group, + quantize_scales, + quantize_groups, + merge_count, + mlp_extra_grouping) + + self.norm_w = 
nn.Parameter(torch.Tensor(self.config.hidden_size)) + self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) + def forward(self, input, input_mask=None, @@ -596,13 +624,15 @@ class DeepSpeedTransformerInference(nn.Module): self.norm_b) if get_present: - attention_output, p_key, p_value, _ = attention_output - presents = (p_key, p_value) + presents = (attention_output[1], attention_output[2]) elif output_attentions: - attention_output, _, _, context_output = attention_output - else: - attention_output, _, _, _ = attention_output - output = self.mlp(attention_output, input, self.attention.attn_ob) + context_output = attention_output[3] + + output = self.mlp( + attention_output[0] + if self.config.mlp_after_attn else attention_output[-1], + input, + self.attention.attn_ob) if not self.config.pre_layer_norm: ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \ @@ -612,8 +642,13 @@ class DeepSpeedTransformerInference(nn.Module): self.norm_b, self.config.epsilon) - if input_type != output.dtype: - output = output.to(input_type) + if not self.config.mlp_after_attn: + inference_cuda_module.gptj_residual_add(output, + input, + attention_output[0], + self.mlp.output_b) + + output = output.to(input_type) if get_present: output = (output, presents) diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 94db63711..f04bad230 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -19,6 +19,7 @@ class InferenceBuilder(CUDAOpBuilder): 'csrc/transformer/inference/csrc/normalize.cu', 'csrc/transformer/inference/csrc/softmax.cu', 'csrc/transformer/inference/csrc/dequantize.cu', + 'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu', ] def include_paths(self):
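With `HFGPTJLayerPolicy` registered in `replace_policies` and the new kernel file added to the op builder, GPT-J injection should go through the usual inference entry point. A minimal usage sketch (not part of the patch; the checkpoint name, dtype, and generation arguments are illustrative):

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-j-6B"  # any HF checkpoint whose layers are GPTJBlock
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# replace_method="auto" walks replace_policies, so HFGPTJLayerPolicy matches the
# GPTJBlock layers and swaps in DeepSpeedTransformerInference with
# rotary_dim=config.rotary_dim and mlp_after_attn=False.
model = deepspeed.init_inference(model, mp_size=1, dtype=torch.float16, replace_method="auto")

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(torch.cuda.current_device())
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0]))
```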