From aca7fc549a415e3e8e0f69ea8ce58735b178b4d0 Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Tue, 8 Jun 2021 11:44:59 -0700 Subject: [PATCH] Add local attention for GPT-Neo model architecture (#1114) * fix links for inference tutorial * Fix automatic injection. Add the local-attention for GPT-Neo * fix the inference for generation of large sequences (>1K & <32K) * fix format Co-authored-by: Jeff Rasley --- .../transformer/inference/csrc/pt_binding.cpp | 21 +- csrc/transformer/inference/csrc/softmax.cu | 248 +++++++++++++----- .../inference/includes/custom_cuda_layers.h | 2 + deepspeed/inference/engine.py | 5 + deepspeed/module_inject/replace_module.py | 42 +-- .../inference/transformer_inference.py | 21 +- 6 files changed, 242 insertions(+), 97 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 5ce859909..f13a5cbd4 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -13,7 +13,9 @@ template at::Tensor ds_softmax(at::Tensor& attn_scores, at::Tensor& attn_mask, bool triangular, - bool recompute) + bool recompute, + bool local_attention, + int window_size) { auto attn_scores_c = attn_scores.contiguous(); int bsz = attn_scores_c.size(0); @@ -25,6 +27,8 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, (T*)attn_mask.data_ptr(), triangular, recompute, + local_attention, + window_size, bsz, heads, seq_len, @@ -47,7 +51,9 @@ void attention_unfused(at::Tensor& prev_key_cont, int& heads, float& norm_factor, bool triangular, - bool recompute) + bool recompute, + bool local_attention, + int window_size) { auto options = at::TensorOptions() .dtype(query_cont.options().dtype()) @@ -75,7 +81,8 @@ void attention_unfused(at::Tensor& prev_key_cont, seq_len * soft_len, bsz * heads, CUBLAS_GEMM_DEFAULT_TENSOR_OP); - attn_score = ds_softmax(attn_score, attn_mask, triangular, recompute); + attn_score = + ds_softmax(attn_score, attn_mask, triangular, recompute, local_attention, window_size); alpha = 1.0; cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), k, @@ -105,7 +112,9 @@ std::vector ds_softmax_context(at::Tensor& query, int heads, float norm_factor, bool merging, - bool triangular) + bool triangular, + bool local_attention, + int window_size) { auto query_cont = query.contiguous(); auto prev_key_cont = prev_key.contiguous(); @@ -138,7 +147,9 @@ std::vector ds_softmax_context(at::Tensor& query, heads, norm_factor, (triangular && (new_size == 0)), - (new_size == 0)); + (new_size == 0), + local_attention, + window_size); return {output, prev_key, prev_value}; } diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index cee509965..3ffad01b6 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -6,10 +6,8 @@ #include #include -#define Attn_Threads 128 -#define Reduce_Threads 32 -#define attn_warps 4 -#define MAX_ATTN_REG 4 // MAX Head Size 256 +#define ATTN_THREADS 1024 +#define MAX_REG_SIZE 8 #define minus_infinity (-1 * std::numeric_limits::infinity()) @@ -26,31 +24,40 @@ void CheckCudaErrorAux(const char* file, unsigned line) namespace cg = cooperative_groups; -template __global__ void attn_softmax_v2(__half* vals, __half* mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int total_count, int heads, int sequence_length, int num_seq, - float scale) + float scale, + int iterations, + int reduceWidth) { #if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); + cg::thread_block_tile g = cg::tiled_partition(b); - float2 low_data[tbSeq]; - float2 high_data[tbSeq]; + float2 low_data[MAX_REG_SIZE]; + float2 high_data[MAX_REG_SIZE]; __half2 h_scale = __float2half2_rn(scale); int wid = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; - int iter_offset = blockIdx.x * (blockDim.x >> 5) + wid; + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); if (iter_offset < total_count) { vals += (iter_offset * sequence_length); @@ -58,20 +65,33 @@ __global__ void attn_softmax_v2(__half* vals, int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); int seq_id = iter_offset % num_seq; int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + float max_val = minus_infinity; - for (int i = 0; i < tbSeq; i++) { - int data_id = i * (WARP_SIZE << 2) + (lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) { + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { - low_data[i].x = __half2float(vals[data_id]); - low_data[i].y = (!triangular || ((data_id + 1) <= seq_id)) + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) ? __half2float(vals[data_id + 1]) : minus_infinity; - high_data[i].x = (!triangular || ((data_id + 2) <= seq_id)) + high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) ? __half2float(vals[data_id + 2]) : minus_infinity; - high_data[i].y = (!triangular || ((data_id + 3) <= seq_id)) + high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) ? __half2float(vals[data_id + 3]) : minus_infinity; if (mask && !triangular && recompute) { @@ -81,12 +101,15 @@ __global__ void attn_softmax_v2(__half* vals, high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); } } else { - low_data[i].x = __half2float(vals[data_id]); - low_data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && + (data_id + 1) > window_stride) && (data_id + 1) < sequence_length) ? __half2float(vals[data_id + 1]) : minus_infinity; - high_data[i].x = (((!triangular || (data_id + 2) <= seq_id)) && + high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && + (data_id + 2) > window_stride) && (data_id + 2) < sequence_length) ? __half2float(vals[data_id + 2]) : minus_infinity; @@ -112,13 +135,29 @@ __global__ void attn_softmax_v2(__half* vals, } } - for (int i = 1; i < tbSize; i *= 2) { + for (int i = 1; i < WARP_SIZE; i *= 2) { auto temp = g.shfl_xor(max_val, i); max_val = (temp > max_val ? temp : max_val); } + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + float sum = 0; - for (int i = 0; i < tbSeq; i++) { + for (int i = 0; i < iterations; i++) { low_data[i].x = __expf(low_data[i].x - max_val); low_data[i].y = __expf(low_data[i].y - max_val); high_data[i].x = __expf(high_data[i].x - max_val); @@ -127,12 +166,24 @@ __global__ void attn_softmax_v2(__half* vals, sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); } - for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i); + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } sum += 1e-6; - for (int i = 0; i < tbSeq; i++) { - int data_id = i * (WARP_SIZE << 2) + (lane << 2); + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); if (data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { @@ -151,46 +202,69 @@ __global__ void attn_softmax_v2(__half* vals, #endif } -template __global__ void attn_softmax_v2(float* vals, float* attn_mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int total_count, int heads, int sequence_length, int num_seq, - float scale) + float scale, + int iterations, + int reduceWidth) { cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); + cg::thread_block_tile g = cg::tiled_partition(b); - float4 data[tbSeq]; + float4 data[MAX_REG_SIZE]; int wid = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; int warp_num = blockDim.x >> 5; - int iter_offset = blockIdx.x * warp_num + wid; + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); if (iter_offset < total_count) { vals += (iter_offset * sequence_length); int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); int seq_id = iter_offset % num_seq; int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + float max_val = minus_infinity; - for (int i = 0; i < tbSeq; i++) { - int data_id = i * (WARP_SIZE << 2) + (lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) { + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { - data[i].x = (vals[data_id]); - data[i].y = (!triangular || ((data_id + 1) <= seq_id)) ? (vals[data_id + 1]) - : minus_infinity; - data[i].z = (!triangular || ((data_id + 2) <= seq_id)) ? (vals[data_id + 2]) - : minus_infinity; - data[i].w = (!triangular || ((data_id + 3) <= seq_id)) ? (vals[data_id + 3]) - : minus_infinity; + data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); + data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) + ? vals[data_id + 1] + : minus_infinity; + data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) + ? vals[data_id + 2] + : minus_infinity; + data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) + ? vals[data_id + 3] + : minus_infinity; if (attn_mask && !triangular && recompute) { data[i].x += attn_mask[data_id + mask_offset]; data[i].y += attn_mask[data_id + mask_offset + 1]; @@ -198,13 +272,13 @@ __global__ void attn_softmax_v2(float* vals, data[i].w += attn_mask[data_id + mask_offset + 3]; } } else { - data[i].x = (vals[data_id]); + data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && - (data_id + 1) < sequence_length) + (data_id + 1) > window_stride && (data_id + 1) < sequence_length) ? (vals[data_id + 1]) : minus_infinity; data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && - (data_id + 2) < sequence_length) + (data_id + 2) > window_stride && (data_id + 2) < sequence_length) ? (vals[data_id + 2]) : minus_infinity; data[i].w = minus_infinity; @@ -228,13 +302,29 @@ __global__ void attn_softmax_v2(float* vals, } } - for (int i = 1; i < tbSize; i *= 2) { + for (int i = 1; i < WARP_SIZE; i *= 2) { auto temp = g.shfl_xor(max_val, i); max_val = (temp > max_val ? temp : max_val); } + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + float sum = 0; - for (int i = 0; i < tbSeq; i++) { + for (int i = 0; i < iterations; i++) { data[i].x = __expf(data[i].x - max_val); data[i].y = __expf(data[i].y - max_val); data[i].z = __expf(data[i].z - max_val); @@ -243,12 +333,24 @@ __global__ void attn_softmax_v2(float* vals, sum += (data[i].x + data[i].y + data[i].z + data[i].w); } - for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i); + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } sum += 1e-6; - for (int i = 0; i < tbSeq; i++) { - int data_id = i * (WARP_SIZE << 2) + (lane << 2); + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); if (data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { @@ -271,6 +373,8 @@ void launch_attn_softmax_v2(T* vals, T* mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int batch_size, int heads, int num_seq, @@ -279,36 +383,36 @@ void launch_attn_softmax_v2(T* vals, cudaStream_t stream) { int total_count = batch_size * heads * num_seq; - dim3 grid_dim((total_count - 1) / 32 + 1); - dim3 block_dim(1024); - if (sequence_length <= 128) - attn_softmax_v2<32, 1><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); - else if (sequence_length <= 256) - attn_softmax_v2<32, 2><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); - else if (sequence_length <= 512) - attn_softmax_v2<32, 4><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); - else if (sequence_length <= 1024) - attn_softmax_v2<32, 8><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); - else if (sequence_length <= 2048) - attn_softmax_v2<32, 16><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); - else if (sequence_length <= 4096) - attn_softmax_v2<32, 32><<>>( - vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale); + dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); + dim3 block_dim(ATTN_THREADS); + + const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; + const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; + + if (sequence_length <= 32768) + attn_softmax_v2<<>>(vals, + mask, + triangular, + recompute, + local_attention, + window_size, + total_count, + heads, + sequence_length, + num_seq, + scale, + iterations, + reduce_width); else - throw std::runtime_error( - "Unsupport Seq_Length! Check the restriction of the max_threads and " - "max_thread_iterations!"); + throw std::runtime_error("Unsupport Seq_Length!"); } template void launch_attn_softmax_v2(float* vals, float* mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int batch_size, int heads, int num_seq, @@ -319,6 +423,8 @@ template void launch_attn_softmax_v2(__half* vals, __half* mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int batch_size, int heads, int num_seq, diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index b2264c5c2..b544517fa 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -18,6 +18,8 @@ void launch_attn_softmax_v2(T* vals, T* mask, bool triangular, bool recompute, + bool local_attention, + int window_size, int batch_size, int heads, int num_seq, diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 765d39b7f..42ec654a7 100644 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -78,8 +78,12 @@ class InferenceEngine(Module): def _get_model_config_generate(self): if hasattr(self.module, 'config'): self.config = self.module.config + else: + self.config = None if hasattr(self.module, 'generate'): self.generate = self.module.generate + else: + self.generate = None def _create_model_parallel_group(self): # Call the init process @@ -134,6 +138,7 @@ class InferenceEngine(Module): policy=injection_policy, mp_size=self.mp_world_size, mp_group=self.mp_group, + config=self.config, fp16=(self.dtype == torch.half), training=False, quantize=(self.dtype == torch.int8), diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 31f16cce8..1cc7e63ff 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -90,7 +90,7 @@ def replace_transformer_layer(orig_layer_impl, model, policy=None, micro_batch_size=-1, - bert_config=None, + config=None, seed=-1, hidden_size=-1, num_attention_heads=-1, @@ -110,7 +110,7 @@ def replace_transformer_layer(orig_layer_impl, model (torch.nn.Module): user's nn.module representing their model policy: shows the policy for mapping from the orig_layer_impl to transformer parameters micro_batch_size (int): micro batch size per gpu used during training/eval - bert_config (dict): model config containing hidden size, attention heads, etc. + config (dict): model config containing hidden size, attention heads, etc. seed (int): random seed value max_seq_length (int): max sequence length for training hidden_size (int): hidden dimension @@ -155,7 +155,6 @@ def replace_transformer_layer(orig_layer_impl, _4hh_w = _4hh_w.half() if quantize or fp16: - qkvb = qkvb.half() dense_b = dense_b.half() _h4h_b = _h4h_b.half() _4hh_b = _4hh_b.half() @@ -175,7 +174,12 @@ def replace_transformer_layer(orig_layer_impl, mp_size=mp_size, q_int8=quantize, encoder_decoder=(True if policy_cls is HFBertLayerPolicy else False), - triangular_masking=(policy_cls is not HFBertLayerPolicy)) + triangular_masking=(policy_cls is not HFBertLayerPolicy), + local_attention=((config.attention_layers[layer_id] == "local") + if hasattr(config, + 'attention_layers') else False), + window_size=(config.window_size if hasattr(config, + 'window_size') else 1)) if quantize and quantize_settings is not None: (quantization_scales, @@ -208,8 +212,8 @@ def replace_transformer_layer(orig_layer_impl, new_module.config.scale_attention = scale_attention # we want the weights in [input, output] shape - # linear layer is created is created with [input, output] shape - # we transpose it here to reduce inference cost! + # linear layer is created with [input, output] shape + # transpose it here to reduce inference cost! def transpose(data): data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) @@ -228,6 +232,7 @@ def replace_transformer_layer(orig_layer_impl, qkvw) if qkvb is not None: + qkvb = qkvb.half() attn_block.attn_qkvb.data = mp_replace.qkv_copy( attn_block.attn_qkvb.data, qkvb) @@ -250,12 +255,12 @@ def replace_transformer_layer(orig_layer_impl, else: transformer_config = deepspeed.DeepSpeedTransformerConfig( batch_size=micro_batch_size, - hidden_size=bert_config.hidden_size, - heads=bert_config.num_attention_heads, - attn_dropout_ratio=bert_config.attention_probs_dropout_prob, - hidden_dropout_ratio=bert_config.hidden_dropout_prob, - num_hidden_layers=bert_config.num_hidden_layers, - initializer_range=bert_config.initializer_range, + hidden_size=config.hidden_size, + heads=config.num_attention_heads, + attn_dropout_ratio=config.attention_probs_dropout_prob, + hidden_dropout_ratio=config.hidden_dropout_prob, + num_hidden_layers=config.num_hidden_layers, + initializer_range=config.initializer_range, seed=seed, fp16=fp16, pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln), @@ -302,20 +307,20 @@ def replace_transformer_layer(orig_layer_impl, _replace_policy=policy) -def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False): +def revert_transformer_layer(orig_layer_impl, model, config, preln=False): """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced, e.g., transformers.modeling_bert.BertLayer. model (torch.nn.Module): user's nn.module representing their model - bert_config (dict): model config containing hidden size, attention heads, etc. + config (dict): model config containing hidden size, attention heads, etc. Returns: Updated nn.module with original bert-style transformer layers """ def replace_fn(child, _replace_policy, layer_id): #from turing.nvidia_modelingpreln import BertLayer - orig_module = orig_layer_impl(bert_config) + orig_module = orig_layer_impl(config) # copy relevant state from child -> original module qkvw = child.attn_qkvw.data @@ -389,8 +394,11 @@ def replace_module(model, orig_class, replace_fn, _replace_policy): for plcy in replace_policies: # instantiate a throw-away policy in order to populate the _orig_layer_class _ = plcy(None) - assert plcy._orig_layer_class != None - policy.update({plcy._orig_layer_class: (replace_fn, plcy)}) + if plcy._orig_layer_class is not None: + policy.update({plcy._orig_layer_class: (replace_fn, plcy)}) + assert len(policy.items()) > 0,\ + "No default policy found! Please specifiy your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\ + "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py" replaced_module, _ = _replace_module(model, policy) return replaced_module diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 0d41fe2b6..40a145851 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -74,7 +74,9 @@ class DeepSpeedInferenceConfig(TransformerConfig): stochastic_mode=False, encoder_decoder=False, scale_attention=True, - triangular_masking=True): + triangular_masking=True, + local_attention=False, + window_size=256): super(DeepSpeedInferenceConfig, self).__init__( hidden_size, @@ -92,6 +94,8 @@ class DeepSpeedInferenceConfig(TransformerConfig): self.scale_attention = scale_attention self.specialized_mode = None self.triangular_masking = triangular_masking + self.local_attention = local_attention + self.window_size = window_size @classmethod def from_dict(cls, json_object): @@ -218,7 +222,9 @@ class DeepSpeedSelfAttentionFunction(Function): num_attention_heads_per_partition, (1 / norm_factor if config.scale_attention else 1.0), (not unfused_mode), - config.triangular_masking) + config.triangular_masking, + config.local_attention, + config.window_size) else: attn_key_value = score_context_func( mixed_query, @@ -230,8 +236,10 @@ class DeepSpeedSelfAttentionFunction(Function): num_attention_heads_per_partition, (1 / norm_factor if config.scale_attention else 1.0), (not unfused_mode), - config.triangular_masking) - + config.triangular_masking, + config.local_attention, + config.window_size) + #import pdb;pdb.set_trace() if unfused_mode: context_layer, _, _ = attn_key_value else: @@ -522,6 +530,8 @@ class DeepSpeedTransformerInference(nn.Module): of a Transformer layer. We use this feature for quantization to reduce the convergence impact for specific downstream tasks. """ + layer_id = 0 + def __init__(self, config, mp_group=None, @@ -533,6 +543,9 @@ class DeepSpeedTransformerInference(nn.Module): super(DeepSpeedTransformerInference, self).__init__() self.config = config + self.config.layer_id = DeepSpeedTransformerInference.layer_id + DeepSpeedTransformerInference.layer_id += 1 + self.attention = DeepSpeedSelfAttention(config, mp_group, quantize_scales,