Mirror of https://github.com/microsoft/DeepSpeed.git
Enable StarCoder (kv_head=1) AutoTP (#4896)
Hi, this PR aims to enable AutoTP for StarCoder-style models (kv_head=1). Please kindly review. Thanks~ Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent: f4f31317ed
Commit: 85132adc31
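For context: StarCoder-style models (Hugging Face GPTBigCodeBlock) use multi-query attention, so the fused c_attn weight stores n_embd rows of query projections followed by a single shared key/value block. Under tensor parallelism only the query rows get sharded across ranks while the KV rows are replicated on every rank, which is what this patch teaches AutoTP to do. The snippet below is a minimal standalone sketch of that layout; the toy sizes, variable names, and the mp_size/rank values are illustrative assumptions, not part of the patch.

import torch

# Minimal sketch (not from the patch): toy sizes for a kv_head=1 fused QKV weight.
n_embd = 8            # hidden size == number of query rows in the fused weight
head_dim = 2          # width of the single shared key/value head
mp_size, rank = 2, 0  # hypothetical 2-way tensor-parallel group, rank 0

# Fused weight layout for multi-query attention: [q (n_embd rows); k (head_dim); v (head_dim)]
fused_qkv = torch.randn(n_embd + 2 * head_dim, n_embd)

q, kv = fused_qkv[:n_embd], fused_qkv[n_embd:]   # query rows vs. shared KV rows
q_shard = q.chunk(mp_size, dim=0)[rank]          # each rank keeps only its slice of the queries
local_weight = torch.cat((q_shard, kv), dim=0)   # ...plus a full replica of K and V

print(local_weight.shape)  # torch.Size([8, 8]): n_embd/mp_size + 2*head_dim rows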
@@ -4,7 +4,7 @@
 # DeepSpeed Team

 import torch
 from deepspeed.utils.logging import warning_once
-from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads
+from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
 import re
@@ -17,7 +17,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):


 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

     if mp_size == 1:
         return False
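For illustration, adding 'c_attn' to the list means parameter names from the Hugging Face GPTBigCode implementation (conventionally ending in attn.c_attn) now take the fused-QKV sharding path when mp_size > 1. A simplified, self-contained restatement of the check (a sketch, not the exact helper in the hunk above):

def require_tp_fused_qkvw_sketch(name, mp_size):
    # Substring match against known fused-QKV layer names; only relevant when sharding.
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']
    if mp_size == 1:
        return False
    return any(fused_name in name for fused_name in fused_qkvw_name_list)

# Assumed GPTBigCode-style parameter name, shown for illustration only.
print(require_tp_fused_qkvw_sketch('transformer.h.0.attn.c_attn.weight', mp_size=2))  # True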
@@ -38,6 +38,7 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         "MptBlock": 'glmtype',
         "BaichuanLayer": 'glmtype',
         "DecoderLayer": 'glmtype',
+        "GPTBigCodeBlock": 'bigcodetype'
     }

     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -74,6 +75,14 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
         return split_fusedqkv[gpu_index]

+    def _bigcode_type_transpose(input, mp_size):
+        n_embd = get_n_embd()
+        q = input[:n_embd]
+        kv = input[n_embd:]
+        shape = q.shape
+        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
+        return torch.cat((split_q[gpu_index], kv), dim=0)
+
     def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

         # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
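A rough usage sketch of the new bigcode-type transpose: shard the query rows evenly, replicate the shared key/value block. The helper below is a standalone approximation (gpu_index is passed explicitly here, whereas the patch captures it from the enclosing prepare_tp_fused_qkvw scope, and chunk stands in for the get_shard_size_list-based split); all shapes are made up.

import torch

def bigcode_type_transpose_sketch(weight, mp_size, gpu_index, n_embd):
    # Keep this rank's slice of the query rows and a full copy of the shared KV rows.
    q = weight[:n_embd]
    kv = weight[n_embd:]
    split_q = q.chunk(mp_size, dim=0)
    return torch.cat((split_q[gpu_index], kv), dim=0)

# Toy fused c_attn weight: 16 query rows plus 2 * 4 shared KV rows (illustrative sizes).
w = torch.randn(16 + 8, 16)
local = bigcode_type_transpose_sketch(w, mp_size=4, gpu_index=1, n_embd=16)
print(local.shape)  # torch.Size([12, 16]): 16/4 query rows + 8 replicated KV rows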
@@ -87,6 +96,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
             return _codegen_type_transpose(src, mp_size)
         elif fused_qkv_type == 'glmtype':
             return _glm_type_transpose(src, mp_size)
+        elif fused_qkv_type == 'bigcodetype':
+            return _bigcode_type_transpose(src, mp_size)

         raise ValueError("unknown fused_qkv_type")

@@ -16,7 +16,7 @@ from .replace_policy import replace_policies, generic_policies
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd

 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -278,6 +278,18 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
         # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
         set_num_kv_heads(num_kv_heads)

+        # 4.1 Get n_embd
+        n_embd = None
+        multi_query_n_embd_names = ['n_embd']
+        for name in multi_query_n_embd_names:
+            if hasattr(model_config, name):
+                n_embd = getattr(model_config, name)
+            if n_embd != None:
+                break
+
+        # 4.2 set n_embd
+        set_n_embd(n_embd)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()

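The n_embd probing above is a plain hasattr/getattr scan over candidate config attribute names (currently just 'n_embd', the name used by GPT-2-style configs such as GPTBigCode). A standalone sketch of the same pattern, using a dummy object in place of a real Hugging Face config:

from types import SimpleNamespace

# Dummy stand-in for model_config; the attribute values are illustrative.
model_config = SimpleNamespace(n_embd=6144, n_head=48)

n_embd = None
multi_query_n_embd_names = ['n_embd']
for name in multi_query_n_embd_names:
    if hasattr(model_config, name):
        n_embd = getattr(model_config, name)
    if n_embd is not None:
        break

print(n_embd)  # 6144; stays None for configs that expose no n_embd attribute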
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
     num_kv_heads = num


+def set_n_embd(num):
+    global n_embd
+    n_embd = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     return num_kv_heads
@@ -32,6 +37,11 @@ def get_shard_size(total_size, mp_size, rank=None):
         assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"


+def get_n_embd():
+    global n_embd
+    return n_embd
+
+
 def get_shard_size_list(total_size, mp_size):
     shard_sizes = []
     for i in range(mp_size):
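For completeness, a simplified illustration of what the shard-size helpers compute: units (e.g. attention heads) are spread as evenly as possible over the tensor-parallel ranks, which is what allows uneven division when num_kv_heads is defined. This is an approximation for illustration, not the exact DeepSpeed implementation:

def shard_size_list_sketch(total_size, mp_size, granularity=1):
    # Split total_size into mp_size shards in multiples of granularity,
    # handing the remainder to the lowest ranks.
    units = total_size // granularity
    base, extra = divmod(units, mp_size)
    return [(base + (1 if r < extra else 0)) * granularity for r in range(mp_size)]

print(shard_size_list_sketch(6144, 4))                        # [1536, 1536, 1536, 1536]
print(shard_size_list_sketch(10 * 128, 4, granularity=128))   # [384, 384, 256, 256]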