Mirror of https://github.com/microsoft/DeepSpeed.git
fix MegatronLayerPolicy to be compatible with the newest ParallelTransformerLayer (#4236)
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
This commit is contained in:
Parent: 5dbc531328
Commit: 6cbf666131
@@ -51,14 +51,21 @@ class MegatronLayerPolicy(TransformerPolicy):
                 try:
                     from megatron.model.transformer import ParallelTransformerLayer
                     MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
+                    MegatronLayerPolicy.version = 1
                 except ImportError:
                     MegatronLayerPolicy._orig_layer_class = None
 
     def get_hidden_heads(self):
-        return self.client_module.attention.query_key_value.weight.shape[1], \
-                self.client_module.attention.num_attention_heads, \
-                self.client_module.input_layernorm.eps, \
-                DEFAULT_INTERMEDIATE_SIZE
+        if MegatronLayerPolicy.version == 0:
+            return self.client_module.attention.query_key_value.weight.shape[1], \
+                    self.client_module.attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
+        else:
+            return self.client_module.self_attention.query_key_value.weight.shape[1], \
+                    self.client_module.self_attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
 
     def attention(self, enable_training=False):
         if self.inference:
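For context on the change: newer Megatron-LM builds expose the layer's attention block as self_attention rather than attention, so the policy records version = 1 when the ParallelTransformerLayer import succeeds and keys get_hidden_heads off that flag. Below is a minimal, self-contained sketch of the same version-dispatch pattern; OldLayer, NewLayer, the _Attention stub, and the standalone get_hidden_heads helper are hypothetical stand-ins for illustration only (not DeepSpeed or Megatron-LM code), and the DEFAULT_INTERMEDIATE_SIZE sentinel value is assumed for this sketch.

    # Sketch of version-based attribute dispatch, assuming the old layout keeps
    # attention under `.attention` and the new layout renames it `.self_attention`.
    import torch.nn as nn

    DEFAULT_INTERMEDIATE_SIZE = -1  # placeholder sentinel meaning "infer later" (assumed value)


    class _Attention(nn.Module):

        def __init__(self, hidden_size, num_heads):
            super().__init__()
            # fused QKV projection: weight shape is (3 * hidden_size, hidden_size)
            self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
            self.num_attention_heads = num_heads


    class OldLayer(nn.Module):
        """Old-style layer: attention lives under `.attention`."""

        def __init__(self, hidden_size=1024, num_heads=16):
            super().__init__()
            self.attention = _Attention(hidden_size, num_heads)
            self.input_layernorm = nn.LayerNorm(hidden_size)


    class NewLayer(nn.Module):
        """New-style layer: attention was renamed to `.self_attention`."""

        def __init__(self, hidden_size=1024, num_heads=16):
            super().__init__()
            self.self_attention = _Attention(hidden_size, num_heads)
            self.input_layernorm = nn.LayerNorm(hidden_size)


    def get_hidden_heads(client_module, version):
        """Return (hidden_size, num_heads, layernorm_eps, intermediate_size) for either layout."""
        attn = client_module.attention if version == 0 else client_module.self_attention
        return (attn.query_key_value.weight.shape[1],
                attn.num_attention_heads,
                client_module.input_layernorm.eps,
                DEFAULT_INTERMEDIATE_SIZE)


    if __name__ == "__main__":
        print(get_hidden_heads(OldLayer(), version=0))  # (1024, 16, 1e-05, -1)
        print(get_hidden_heads(NewLayer(), version=1))  # (1024, 16, 1e-05, -1)

The point of the pattern, as in the patch, is that the same tuple of shapes is read from whichever attribute name the installed Megatron-LM version provides, so downstream injection code does not need to care which layout it received.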