fix MegatronLayerPolicy to be compatible with the newest ParallelTransformerLayer (#4236)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
This commit is contained in:
Dino Chen 2023-08-31 07:28:43 +08:00 коммит произвёл GitHub
Родитель 5dbc531328
Коммит 6cbf666131
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 11 добавлений и 4 удалений

Просмотреть файл

@ -51,14 +51,21 @@ class MegatronLayerPolicy(TransformerPolicy):
try:
from megatron.model.transformer import ParallelTransformerLayer
MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
MegatronLayerPolicy.version = 1
except ImportError:
MegatronLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):
return self.client_module.attention.query_key_value.weight.shape[1], \
self.client_module.attention.num_attention_heads, \
self.client_module.input_layernorm.eps, \
DEFAULT_INTERMEDIATE_SIZE
if MegatronLayerPolicy.version == 0:
return self.client_module.attention.query_key_value.weight.shape[1], \
self.client_module.attention.num_attention_heads, \
self.client_module.input_layernorm.eps, \
DEFAULT_INTERMEDIATE_SIZE
else:
return self.client_module.self_attention.query_key_value.weight.shape[1], \
self.client_module.self_attention.num_attention_heads, \
self.client_module.input_layernorm.eps, \
DEFAULT_INTERMEDIATE_SIZE
def attention(self, enable_training=False):
if self.inference: