Update falcon fused type order (#5007)

The selection of fused type depends on the order of fused_type_dict.
If put “DecoderLayer” in front of “FalconDecoderLayer”, Falcon will
still choose glmtype incorrectly, so need to put “DecoderLayer at” the
last position of fused_type_dict.

---------

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Yejing-Lai 2024-01-26 14:42:31 +08:00 коммит произвёл GitHub
Родитель e81369318e
Коммит 62afafe812
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 8 добавлений и 5 удалений

Просмотреть файл

@ -5,7 +5,6 @@
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
import re
def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
@ -39,10 +38,10 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
"MPTBlock": 'glmtype',
"MptBlock": 'glmtype',
"BaichuanLayer": 'glmtype',
"DecoderLayer": 'glmtype',
"QWenBlock": 'qwentype',
"FalconDecoderLayer": 'bloomtype',
"GPTBigCodeBlock": 'bigcodetype',
"DecoderLayer": 'glmtype',
}
def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@ -114,8 +113,12 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
raise ValueError("unknown fused_qkv_type")
for module_name, fused_type in fused_type_dict.items():
if re.search(module_name, module_str):
module_name_matches = [k for k in fused_type_dict.keys() if module_str in k]
if module_name_matches:
# There can be overlap with matches (e.g., "DecoderLayer" and "FalconDecoderLayer").
# We take the longest matching module_name
module_name = max(module_name_matches, key=len)
fused_type = fused_type_dict[module_name]
return _transpose_fused_qkvw(src, mp_size, fused_type, module)
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")