зеркало из https://github.com/microsoft/DeepSpeed.git
Add fp16 support of Qwen1.5MoE models (A2.7B) to DeepSpeed-FastGen (#5403)
This PR adds support for Qwen1.5MoE-A2.7B models. support for https://github.com/microsoft/DeepSpeed-MII/issues/457 ### Test Code for mii pipeline: ```python import mii pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B") responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False) if pipe.is_rank_0: print(responses[0]) ``` for huggingface: ```python import mii from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig import torch tokenizer = AutoTokenizer.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B") model = AutoModelForCausalLM.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True).eval() print(model) inputs = tokenizer('DeepSpeed is', return_tensors='pt') inputs = inputs.to(model.device) pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0) test = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False) print(test) ``` ### Qwen1.5-MoE-A2.7B Huggingface output with prompt "DeepSpeed is": ``` a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models. DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs. One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the ``` DeepSpeed-FastGen output with prompt "DeepSpeed is": ``` a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models. DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs. One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the ``` DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding: ``` a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models. DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs. One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the ``` Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Heyang Qin <heyangqin@microsoft.com> Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
This commit is contained in:
Родитель
23d0e0221f
Коммит
249c1db2fb
|
@ -233,6 +233,8 @@ We currently support the following model architectures in this alpha release of
|
|||
* [Phi-2](https://huggingface.co/models?other=phi-msft)
|
||||
* [Phi-3](https://huggingface.co/models?other=phi3)
|
||||
* [Qwen](https://huggingface.co/models?other=qwen)
|
||||
* [Qwen2](https://huggingface.co/models?other=qwen2)
|
||||
* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)
|
||||
|
||||
All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from .model_implementations import (
|
|||
Phi3Policy,
|
||||
QwenPolicy,
|
||||
Qwen2Policy,
|
||||
Qwen2MoePolicy,
|
||||
)
|
||||
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
|
||||
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
|
||||
|
@ -126,6 +127,8 @@ def build_hf_engine(path: str,
|
|||
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
|
||||
elif model_config.model_type == "qwen2":
|
||||
policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
|
||||
elif model_config.model_type == "qwen2_moe":
|
||||
policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type {model_config.model_type}")
|
||||
|
||||
|
|
|
@ -11,5 +11,8 @@
|
|||
} else if (2 == N_TOP_K) { \
|
||||
constexpr int CONST_TOP_K = 2; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (4 == N_TOP_K) { \
|
||||
constexpr int CONST_TOP_K = 4; \
|
||||
__VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
|
|
@ -18,3 +18,4 @@ from .phi import *
|
|||
from .phi3 import *
|
||||
from .qwen import *
|
||||
from .qwen_v2 import *
|
||||
from .qwen_v2_moe import *
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .policy import Qwen2MoePolicy
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
# Create a container object to save model-specific tensors using the policy file above.
|
||||
|
||||
from ..common_parameters import *
|
||||
from ..layer_container_base import LayerContainer
|
||||
'''
|
||||
# HF Qwen1.5-MoE-A2.7B model looks like this:
|
||||
|
||||
Qwen2MoeForCausalLM(
|
||||
(model): Qwen2MoeModel(
|
||||
(embed_tokens): Embedding(151936, 2048)
|
||||
(layers): ModuleList(
|
||||
(0-23): 24 x Qwen2MoeDecoderLayer(
|
||||
(self_attn): Qwen2MoeSdpaAttention(
|
||||
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
|
||||
(k_proj): Linear(in_features=2048, out_features=2048, bias=True)
|
||||
(v_proj): Linear(in_features=2048, out_features=2048, bias=True)
|
||||
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
|
||||
(rotary_emb): Qwen2MoeRotaryEmbedding()
|
||||
)
|
||||
(mlp): Qwen2MoeSparseMoeBlock(
|
||||
(gate): Linear(in_features=2048, out_features=60, bias=False)
|
||||
(experts): ModuleList(
|
||||
(0-59): 60 x Qwen2MoeMLP(
|
||||
(gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
|
||||
(up_proj): Linear(in_features=2048, out_features=1408, bias=False)
|
||||
(down_proj): Linear(in_features=1408, out_features=2048, bias=False)
|
||||
(act_fn): SiLU()
|
||||
)
|
||||
)
|
||||
(shared_expert): Qwen2MoeMLP(
|
||||
(gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
|
||||
(up_proj): Linear(in_features=2048, out_features=5632, bias=False)
|
||||
(down_proj): Linear(in_features=5632, out_features=2048, bias=False)
|
||||
(act_fn): SiLU()
|
||||
)
|
||||
(shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
|
||||
)
|
||||
(input_layernorm): Qwen2MoeRMSNorm()
|
||||
(post_attention_layernorm): Qwen2MoeRMSNorm()
|
||||
)
|
||||
)
|
||||
(norm): Qwen2MoeRMSNorm()
|
||||
)
|
||||
(lm_head): Linear(in_features=2048, out_features=151936, bias=False)
|
||||
)
|
||||
'''
|
||||
|
||||
|
||||
class Qwen2MoeTransformerContainer(LayerContainer):
|
||||
"""
|
||||
Transformer layer container for the Qwen2Moe model.
|
||||
"""
|
||||
qkv_w: UnfusedQKVParameter
|
||||
qkv_b: UnfusedQKVParameter
|
||||
attn_out_w: AttentionOutputParameter
|
||||
moe_gate: MoEGatingWeightParameter
|
||||
moe_mlp_1: UnfusedMoEGatedMLPParameter
|
||||
moe_mlp_2: UnfusedMoEMLP2Parameter
|
||||
shared_moe_mlp_1: GatedMLPParameter
|
||||
shared_moe_mlp_2: MLP2Parameter
|
||||
shared_moe_gate: MoEGatingWeightParameter
|
||||
attn_norm_gamma: NormParameter
|
||||
mlp_norm_gamma: NormParameter
|
||||
|
||||
PARAM_MAPPING = {
|
||||
"self_attn.q_proj.weight": "qkv_w.q_params",
|
||||
"self_attn.k_proj.weight": "qkv_w.k_params",
|
||||
"self_attn.v_proj.weight": "qkv_w.v_params",
|
||||
"self_attn.q_proj.bias": "qkv_b.q_params",
|
||||
"self_attn.k_proj.bias": "qkv_b.k_params",
|
||||
"self_attn.v_proj.bias": "qkv_b.v_params",
|
||||
"self_attn.o_proj.weight": "attn_out_w.params",
|
||||
"mlp.gate.weight": "moe_gate.params",
|
||||
"mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts",
|
||||
"mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts",
|
||||
"mlp.experts.*.down_proj.weight": "moe_mlp_2.experts",
|
||||
"mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params",
|
||||
"mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params",
|
||||
"mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params",
|
||||
"mlp.shared_expert_gate.weight": "shared_moe_gate.params",
|
||||
"input_layernorm.weight": "attn_norm_gamma.params",
|
||||
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
|
||||
}
|
||||
|
||||
|
||||
class Qwen2MoeNonTransformerContainer(LayerContainer):
|
||||
"""
|
||||
Non-Transformer layer container for the Qwen2Moe model.
|
||||
"""
|
||||
word_emb: EmbeddingParameter
|
||||
word_unembed: UnembedParameter
|
||||
final_norm: NormParameter
|
||||
|
||||
PARAM_MAPPING = {
|
||||
"model.embed_tokens.weight": "word_emb.params",
|
||||
"model.norm.weight": "final_norm.params",
|
||||
"lm_head.weight": "word_unembed.params",
|
||||
}
|
|
@ -0,0 +1,359 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import deepspeed.comm as dist
|
||||
|
||||
from ...allocator import empty_from
|
||||
from ...config_v2 import RaggedInferenceEngineConfig
|
||||
from ...inference_utils import ActivationType, DtypeEnum
|
||||
from ...model_implementations import *
|
||||
from ...modules.configs import *
|
||||
from ...modules.interfaces import *
|
||||
from ...modules import heuristics
|
||||
from ...ragged import RaggedBatchWrapper
|
||||
from ..inference_model_base import (
|
||||
DSModelImplementationConfig,
|
||||
MPType,
|
||||
)
|
||||
|
||||
from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
|
||||
|
||||
|
||||
class Qwen2MoeInferenceModel(DSMoETransformerModelBase):
|
||||
"""
|
||||
Inference model implementation for Qwen2MoE models.
|
||||
"""
|
||||
|
||||
_non_transformer: Optional[Qwen2MoeNonTransformerContainer]
|
||||
"""
|
||||
Embed + unembed container. Specializing the type annotation.
|
||||
"""
|
||||
|
||||
_transformer: Optional[Iterable[Qwen2MoeTransformerContainer]]
|
||||
"""
|
||||
Per-layer transformer container. Specializing the type annotation.
|
||||
"""
|
||||
"""
|
||||
Properties ineherited from `DSInferenceModelBase`
|
||||
"""
|
||||
|
||||
@property
|
||||
def max_sequence_length(self) -> int:
|
||||
return self._config.max_position_embeddings
|
||||
|
||||
"""
|
||||
Properties ineherited from `DSTransformerModelBase`
|
||||
"""
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self._config.num_hidden_layers
|
||||
|
||||
@property
|
||||
def model_dim(self) -> int:
|
||||
return self._config.hidden_size
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._config.vocab_size
|
||||
|
||||
@property
|
||||
def head_size(self) -> int:
|
||||
return self.model_dim // self.n_heads
|
||||
|
||||
@property
|
||||
def n_heads(self) -> int:
|
||||
return self._config.num_attention_heads
|
||||
|
||||
@property
|
||||
def intermediate_dim(self) -> int:
|
||||
return self._config.intermediate_size
|
||||
|
||||
@property
|
||||
def n_heads_kv(self) -> int:
|
||||
return self._config.num_key_value_heads
|
||||
|
||||
@property
|
||||
def activation_dtype(self) -> DtypeEnum:
|
||||
# TODO(ZonePG): bf16 inference results may be different from huggingface bf16,
|
||||
# because in rms_norm, Qwen still use float() instead of bf16
|
||||
# if self._config.torch_dtype == torch.float16:
|
||||
# return DtypeEnum.fp16
|
||||
# elif self._config.torch_dtype == torch.bfloat16:
|
||||
# return DtypeEnum.bf16
|
||||
# else:
|
||||
# raise NotImplementedError("Only fp16 and bf16 are supported")
|
||||
return DtypeEnum.fp16
|
||||
|
||||
@property
|
||||
def mlp_activation_fn(self) -> ActivationType:
|
||||
return ActivationType.SiGLU
|
||||
|
||||
@property
|
||||
def norm_type(self) -> NormTypeEnum:
|
||||
return NormTypeEnum.RMSNorm
|
||||
|
||||
@property
|
||||
def positional_embedding_type(self) -> PositionalEmbeddingType:
|
||||
return PositionalEmbeddingType.rotate_half
|
||||
|
||||
@property
|
||||
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
|
||||
return RotateHalfConfig(theta_base=self._config.rope_theta)
|
||||
|
||||
"""
|
||||
Inherited from `DSMoETransformerModelBase`
|
||||
"""
|
||||
|
||||
@property
|
||||
def n_experts(self) -> int:
|
||||
return self._config.num_experts
|
||||
|
||||
@property
|
||||
def n_top_k(self) -> int:
|
||||
return self._config.num_experts_per_tok
|
||||
|
||||
@property
|
||||
def normalize_expert_scores(self) -> bool:
|
||||
return self._config.norm_topk_prob
|
||||
|
||||
def make_moe_layer(self) -> None:
|
||||
"""
|
||||
Instantiates the MoE layer for the model. This sets the `self.moe` attribute.
|
||||
"""
|
||||
sharded_dim = sharded_intermediate_dim(self.intermediate_dim // self.n_top_k, self.tp_size, self.tp_rank)
|
||||
|
||||
moe_config = DSMoEConfig(
|
||||
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
|
||||
model_dim=self.model_dim,
|
||||
intermediate_features=sharded_dim,
|
||||
activation=self.mlp_activation_fn,
|
||||
n_experts=self.n_experts,
|
||||
top_k=self.n_top_k,
|
||||
input_dtype=self.activation_dtype,
|
||||
output_dtype=self.activation_dtype,
|
||||
normalize_scores=self.normalize_expert_scores,
|
||||
)
|
||||
|
||||
self.moe = heuristics.instantiate_moe(moe_config, self._engine_config)
|
||||
|
||||
######### MLP 1 #########
|
||||
def make_shared_expert_mlp_1_layer(self) -> None:
|
||||
"""
|
||||
Instantiates the linear projection layer for the first MLP in the feedforward network.
|
||||
This sets the `self.mlp_1` attribute.
|
||||
"""
|
||||
shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
|
||||
|
||||
linear_config = DSLinearConfig(
|
||||
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
|
||||
in_channels=self.model_dim,
|
||||
out_channels=shard_size,
|
||||
activation=self.mlp_activation_fn,
|
||||
input_dtype=self.activation_dtype,
|
||||
output_dtype=self.activation_dtype,
|
||||
)
|
||||
|
||||
self.shared_expert_mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config)
|
||||
|
||||
######### MLP 2 #########
|
||||
def make_shared_expert_mlp_2_layer(self) -> None:
|
||||
"""
|
||||
Instantiates the linear projection layer for the second MLP in the feedforward network.
|
||||
This sets the `self.mlp_2` attribute.
|
||||
"""
|
||||
shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
|
||||
|
||||
linear_config = DSLinearConfig(
|
||||
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
|
||||
in_channels=shard_size,
|
||||
out_channels=self.model_dim,
|
||||
input_dtype=self.activation_dtype,
|
||||
output_dtype=self.activation_dtype,
|
||||
)
|
||||
|
||||
self.shared_expert_mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config)
|
||||
|
||||
######### MLP 2 #########
|
||||
def make_shared_expert_gate_layer(self) -> None:
|
||||
"""
|
||||
Instantiates the linear projection layer for the second MLP in the feedforward network.
|
||||
This sets the `self.mlp_2` attribute.
|
||||
"""
|
||||
shard_size = sharded_intermediate_dim(self.model_dim, self.tp_size, self.tp_rank)
|
||||
|
||||
linear_config = DSLinearConfig(
|
||||
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
|
||||
in_channels=shard_size,
|
||||
out_channels=8,
|
||||
input_dtype=self.activation_dtype,
|
||||
output_dtype=self.activation_dtype,
|
||||
)
|
||||
|
||||
self.shared_expert_gate = heuristics.instantiate_linear(linear_config, self._engine_config)
|
||||
|
||||
def make_norm_layer(self) -> None:
|
||||
"""
|
||||
Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
|
||||
|
||||
TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
|
||||
but for now we'll just use the same one for all of them.
|
||||
"""
|
||||
norm_config = DSNormConfig(
|
||||
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
|
||||
type=self.norm_type,
|
||||
channels=self.model_dim,
|
||||
residual_dtype=self.activation_dtype,
|
||||
input_dtype=self.activation_dtype,
|
||||
output_dtype=self.activation_dtype,
|
||||
eps=self._config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
|
||||
|
||||
"""
|
||||
Model implementation
|
||||
"""
|
||||
|
||||
def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig,
|
||||
base_mp_group: MPType) -> None:
|
||||
"""
|
||||
Base implementation for initialization. By default, this will initialize
|
||||
the traditional components of a transformer model:
|
||||
- Embedding
|
||||
- QKV projection
|
||||
- Self attention
|
||||
- Attention output projection
|
||||
- Feed forward network
|
||||
- Normalization
|
||||
- Unembedding
|
||||
|
||||
Arguments:
|
||||
config (DSModelImplementationConfig): Model-specific configuration. No assumptions
|
||||
should be made about this config that are not closely tied to the specific
|
||||
model implementation.
|
||||
engine_config (RaggedInferenceEngineConfig): Engine configuration.
|
||||
base_mp_group (MPType): Base communication group for Tensor-parallel inference.
|
||||
"""
|
||||
super().__init__(config, engine_config, base_mp_group)
|
||||
|
||||
self.make_norm_layer()
|
||||
self.make_qkv_layer()
|
||||
self.make_attn_layer()
|
||||
self.make_attn_out_layer()
|
||||
self.make_moe_layer()
|
||||
self.make_shared_expert_mlp_1_layer()
|
||||
self.make_shared_expert_mlp_2_layer()
|
||||
self.make_shared_expert_gate_layer()
|
||||
self.make_embedding_layer()
|
||||
self.make_unembedding_layer()
|
||||
self._kv_cache_config = None
|
||||
|
||||
"""
|
||||
Forward implementations
|
||||
"""
|
||||
|
||||
def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
|
||||
"""
|
||||
Performs the embedding lookup prior to running the transformer of the model.
|
||||
|
||||
Arguments:
|
||||
ragged_batch (RaggedBatchWrapper): The batch to embed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The embedded batch.
|
||||
"""
|
||||
embed = self.embed(ragged_batch, self._non_transformer.word_emb)
|
||||
|
||||
if embed.shape[-1] != self.model_dim:
|
||||
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")
|
||||
|
||||
return embed
|
||||
|
||||
def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
|
||||
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead
|
||||
optimization to fuse the layer norm of the next layer into the current layer.
|
||||
|
||||
Arguments:
|
||||
layer_idx (int): The index of the layer to execute.
|
||||
residual (torch.Tensor): The residual tensor from the previous layer.
|
||||
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
|
||||
hidden states after pre normalization.
|
||||
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
|
||||
"""
|
||||
# TODO(cmikeh2): Distribute ragged_batch_info to all modules
|
||||
|
||||
cur_params = self._transformer[layer_idx]
|
||||
kv_cache = self.state_manager.get_cache(layer_idx)
|
||||
|
||||
hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
|
||||
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
|
||||
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(hidden_states, group=self._base_mp_group)
|
||||
|
||||
residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)
|
||||
|
||||
shared_expert_output = self.shared_expert_mlp_1(hidden_states, cur_params.shared_moe_mlp_1, b=None)
|
||||
shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None)
|
||||
shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1]
|
||||
# shared_expert_gate_output shape[-1] is 1
|
||||
shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output))
|
||||
hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1,
|
||||
cur_params.moe_mlp_2)
|
||||
hidden_states.add_(shared_expert_output)
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(hidden_states, group=self._base_mp_group)
|
||||
|
||||
if layer_idx != self.num_layers - 1:
|
||||
next_params = self._transformer[layer_idx + 1]
|
||||
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
|
||||
else:
|
||||
# On last layer, we just need to perform the residual add. Adding into the residual
|
||||
# here is safe.
|
||||
residual.add_(hidden_states)
|
||||
|
||||
return residual, hidden_states
|
||||
|
||||
def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
|
||||
"""
|
||||
Performs unembedding of the hidden states to logits. This will only sample the final
|
||||
token of each sequence.
|
||||
"""
|
||||
logits = self.unembed(hidden_states,
|
||||
self._non_transformer.word_unembed,
|
||||
ragged_batch_info,
|
||||
gamma=self._non_transformer.final_norm)
|
||||
|
||||
if self.tp_size > 1:
|
||||
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
|
||||
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))
|
||||
|
||||
dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)
|
||||
|
||||
full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))
|
||||
|
||||
return full_logits
|
||||
else:
|
||||
return logits
|
||||
|
||||
def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
|
||||
|
||||
residual = self._forward_embed(wrapped_batch)
|
||||
|
||||
residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch)
|
||||
|
||||
return self._forward_unembed(residual, wrapped_batch)
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...config_v2 import RaggedInferenceEngineConfig
|
||||
from ..inference_policy_base import ContainerMap, InferenceV2Policy
|
||||
from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
|
||||
from .model import Qwen2MoeInferenceModel
|
||||
|
||||
|
||||
class Qwen2MoePolicy(InferenceV2Policy):
|
||||
|
||||
def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2MoeInferenceModel:
|
||||
return Qwen2MoeInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
|
||||
|
||||
def build_container_map(self) -> ContainerMap:
|
||||
map = ContainerMap()
|
||||
|
||||
transformer_containers = [Qwen2MoeTransformerContainer(self.model) for _ in range(self.model.num_layers)]
|
||||
|
||||
map.set_transformer_params(['model.layers'], transformer_containers)
|
||||
|
||||
map.set_non_transformer_params(Qwen2MoeNonTransformerContainer(self.model))
|
||||
|
||||
map.set_unmapped_params([])
|
||||
|
||||
return map
|
|
@ -42,7 +42,7 @@ class DSMultiGemmMoE(DSMoEBase):
|
|||
if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
|
||||
return False
|
||||
|
||||
if config.top_k != 1 and config.top_k != 2:
|
||||
if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
Загрузка…
Ссылка в новой задаче