Add fp16 support for Qwen1.5-MoE models (A2.7B) to DeepSpeed-FastGen (#5403)

This PR adds fp16 inference support for the Qwen1.5-MoE-A2.7B model to DeepSpeed-FastGen.

It addresses https://github.com/microsoft/DeepSpeed-MII/issues/457.

### Test Code

For the MII pipeline:
```python
import mii

pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False)
if pipe.is_rank_0:
    print(responses[0])
```
For HuggingFace `transformers`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
tokenizer = AutoTokenizer.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
model = AutoModelForCausalLM.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True).eval()
print(model)
inputs = tokenizer('DeepSpeed is', return_tensors='pt')
inputs = inputs.to(model.device)
pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0)
test = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
print(test)
```

### Qwen1.5-MoE-A2.7B
HuggingFace output with prompt "DeepSpeed is":
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```
DeepSpeed-FastGen output with prompt "DeepSpeed is":
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```

DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding:
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```
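For reference, the 8-way sharded run above uses the same MII pipeline API with tensor parallelism enabled. A minimal launch sketch, assuming MII's `tensor_parallel` model-config option and the `deepspeed` launcher (model path as above):

```python
# run_tp8.py -- launch with: deepspeed --num_gpus 8 run_tp8.py
import mii

# tensor_parallel=8 shards the model weights across the 8 ranks
pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B", tensor_parallel=8)
responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False)
if pipe.is_rank_0:
    print(responses[0])
```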

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
This commit is contained in:
Perry Zou 2024-08-02 01:27:24 +08:00 committed by GitHub
Parent 23d0e0221f
Commit 249c1db2fb
No key found matching this signature
GPG key ID: B5690EEEBB952194
9 changed files with 508 additions and 1 deletion

@@ -233,6 +233,8 @@ We currently support the following model architectures in this alpha release of
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Phi-3](https://huggingface.co/models?other=phi3)
* [Qwen](https://huggingface.co/models?other=qwen)
* [Qwen2](https://huggingface.co/models?other=qwen2)
* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)
All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.

@@ -23,6 +23,7 @@ from .model_implementations import (
    Phi3Policy,
    QwenPolicy,
    Qwen2Policy,
    Qwen2MoePolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -126,6 +127,8 @@ def build_hf_engine(path: str,
        policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
    elif model_config.model_type == "qwen2":
        policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
    elif model_config.model_type == "qwen2_moe":
        policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
    else:
        raise ValueError(f"Unsupported model type {model_config.model_type}")

@@ -11,5 +11,8 @@
        } else if (2 == N_TOP_K) {         \
            constexpr int CONST_TOP_K = 2; \
            __VA_ARGS__();                 \
        } else if (4 == N_TOP_K) {         \
            constexpr int CONST_TOP_K = 4; \
            __VA_ARGS__();                 \
        }                                  \
    }()

@@ -18,3 +18,4 @@ from .phi import *
from .phi3 import *
from .qwen import *
from .qwen_v2 import *
from .qwen_v2_moe import *

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .policy import Qwen2MoePolicy

@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Create a container object to save model-specific tensors using the policy file above.
from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen1.5-MoE-A2.7B model looks like this:
Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
              (act_fn): SiLU()
            )
          )
          (shared_expert): Qwen2MoeMLP(
            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
        )
        (input_layernorm): Qwen2MoeRMSNorm()
        (post_attention_layernorm): Qwen2MoeRMSNorm()
      )
    )
    (norm): Qwen2MoeRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)
'''
class Qwen2MoeTransformerContainer(LayerContainer):
    """
    Transformer layer container for the Qwen2Moe model.
    """
    qkv_w: UnfusedQKVParameter
    qkv_b: UnfusedQKVParameter
    attn_out_w: AttentionOutputParameter
    moe_gate: MoEGatingWeightParameter
    moe_mlp_1: UnfusedMoEGatedMLPParameter
    moe_mlp_2: UnfusedMoEMLP2Parameter
    shared_moe_mlp_1: GatedMLPParameter
    shared_moe_mlp_2: MLP2Parameter
    shared_moe_gate: MoEGatingWeightParameter
    attn_norm_gamma: NormParameter
    mlp_norm_gamma: NormParameter

    PARAM_MAPPING = {
        "self_attn.q_proj.weight": "qkv_w.q_params",
        "self_attn.k_proj.weight": "qkv_w.k_params",
        "self_attn.v_proj.weight": "qkv_w.v_params",
        "self_attn.q_proj.bias": "qkv_b.q_params",
        "self_attn.k_proj.bias": "qkv_b.k_params",
        "self_attn.v_proj.bias": "qkv_b.v_params",
        "self_attn.o_proj.weight": "attn_out_w.params",
        "mlp.gate.weight": "moe_gate.params",
        "mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts",
        "mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts",
        "mlp.experts.*.down_proj.weight": "moe_mlp_2.experts",
        "mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params",
        "mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params",
        "mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params",
        "mlp.shared_expert_gate.weight": "shared_moe_gate.params",
        "input_layernorm.weight": "attn_norm_gamma.params",
        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
    }


class Qwen2MoeNonTransformerContainer(LayerContainer):
    """
    Non-Transformer layer container for the Qwen2Moe model.
    """
    word_emb: EmbeddingParameter
    word_unembed: UnembedParameter
    final_norm: NormParameter

    PARAM_MAPPING = {
        "model.embed_tokens.weight": "word_emb.params",
        "model.norm.weight": "final_norm.params",
        "lm_head.weight": "word_unembed.params",
    }
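A note on the wildcard entries above: `mlp.experts.*.gate_proj.weight` and the other `mlp.experts.*` mappings gather the 60 per-expert checkpoint tensors into single stacked MoE parameters for the fused kernels. Conceptually the fusion looks like the sketch below (plain `torch`, illustrative only; the actual fused layout is defined by `UnfusedMoEGatedMLPParameter` and `UnfusedMoEMLP2Parameter`, not by this snippet):

```python
import torch

n_experts, model_dim, moe_dim = 60, 2048, 1408

# Per-expert weights as they appear in the HF checkpoint:
#   mlp.experts.{i}.gate_proj.weight -> [1408, 2048]
#   mlp.experts.{i}.up_proj.weight   -> [1408, 2048]
#   mlp.experts.{i}.down_proj.weight -> [2048, 1408]
gate = [torch.empty(moe_dim, model_dim) for _ in range(n_experts)]
up = [torch.empty(moe_dim, model_dim) for _ in range(n_experts)]
down = [torch.empty(model_dim, moe_dim) for _ in range(n_experts)]

# Stacked along a leading expert dimension, roughly what the container builds:
moe_mlp_1 = torch.stack([torch.cat((g, u), dim=0) for g, u in zip(gate, up)])  # [60, 2816, 2048]
moe_mlp_2 = torch.stack(down)                                                  # [60, 2048, 1408]
print(moe_mlp_1.shape, moe_mlp_2.shape)
```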

@@ -0,0 +1,359 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Iterable, Optional, Tuple
import torch
import deepspeed.comm as dist
from ...allocator import empty_from
from ...config_v2 import RaggedInferenceEngineConfig
from ...inference_utils import ActivationType, DtypeEnum
from ...model_implementations import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper
from ..inference_model_base import (
DSModelImplementationConfig,
MPType,
)
from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
class Qwen2MoeInferenceModel(DSMoETransformerModelBase):
"""
Inference model implementation for Qwen2MoE models.
"""
_non_transformer: Optional[Qwen2MoeNonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""
_transformer: Optional[Iterable[Qwen2MoeTransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
"""
Properties inherited from `DSInferenceModelBase`
"""
@property
def max_sequence_length(self) -> int:
return self._config.max_position_embeddings
"""
Properties inherited from `DSTransformerModelBase`
"""
@property
def num_layers(self) -> int:
return self._config.num_hidden_layers
@property
def model_dim(self) -> int:
return self._config.hidden_size
@property
def vocab_size(self) -> int:
return self._config.vocab_size
@property
def head_size(self) -> int:
return self.model_dim // self.n_heads
@property
def n_heads(self) -> int:
return self._config.num_attention_heads
@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size
@property
def n_heads_kv(self) -> int:
return self._config.num_key_value_heads
@property
def activation_dtype(self) -> DtypeEnum:
# TODO(ZonePG): bf16 inference results may differ from HuggingFace bf16,
# because Qwen's rms_norm still upcasts to float() rather than computing in bf16
# if self._config.torch_dtype == torch.float16:
# return DtypeEnum.fp16
# elif self._config.torch_dtype == torch.bfloat16:
# return DtypeEnum.bf16
# else:
# raise NotImplementedError("Only fp16 and bf16 are supported")
return DtypeEnum.fp16
@property
def mlp_activation_fn(self) -> ActivationType:
return ActivationType.SiGLU
@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.RMSNorm
@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half
@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)
"""
Inherited from `DSMoETransformerModelBase`
"""
@property
def n_experts(self) -> int:
return self._config.num_experts
@property
def n_top_k(self) -> int:
return self._config.num_experts_per_tok
@property
def normalize_expert_scores(self) -> bool:
return self._config.norm_topk_prob
def make_moe_layer(self) -> None:
"""
Instantiates the MoE layer for the model. This sets the `self.moe` attribute.
"""
sharded_dim = sharded_intermediate_dim(self.intermediate_dim // self.n_top_k, self.tp_size, self.tp_rank)
moe_config = DSMoEConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
model_dim=self.model_dim,
intermediate_features=sharded_dim,
activation=self.mlp_activation_fn,
n_experts=self.n_experts,
top_k=self.n_top_k,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
normalize_scores=self.normalize_expert_scores,
)
self.moe = heuristics.instantiate_moe(moe_config, self._engine_config)
######### Shared expert MLP 1 #########
def make_shared_expert_mlp_1_layer(self) -> None:
"""
Instantiates the linear projection layer for the first MLP of the shared expert.
This sets the `self.shared_expert_mlp_1` attribute.
"""
shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
linear_config = DSLinearConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
in_channels=self.model_dim,
out_channels=shard_size,
activation=self.mlp_activation_fn,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
)
self.shared_expert_mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config)
######### Shared expert MLP 2 #########
def make_shared_expert_mlp_2_layer(self) -> None:
"""
Instantiates the linear projection layer for the second MLP of the shared expert.
This sets the `self.shared_expert_mlp_2` attribute.
"""
shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
linear_config = DSLinearConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
in_channels=shard_size,
out_channels=self.model_dim,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
)
self.shared_expert_mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config)
######### Shared expert gate #########
def make_shared_expert_gate_layer(self) -> None:
"""
Instantiates the gating projection for the shared expert.
This sets the `self.shared_expert_gate` attribute.
"""
shard_size = sharded_intermediate_dim(self.model_dim, self.tp_size, self.tp_rank)
linear_config = DSLinearConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
in_channels=shard_size,
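# NOTE: out_channels is 8 rather than 1 (the HF shared_expert_gate has out_features=1),
# presumably for kernel alignment; only the first output column is consumed
# (see the [..., :1] slice in _forward_transformer).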
out_channels=8,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
)
self.shared_expert_gate = heuristics.instantiate_linear(linear_config, self._engine_config)
def make_norm_layer(self) -> None:
"""
Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
but for now we'll just use the same one for all of them.
"""
norm_config = DSNormConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
type=self.norm_type,
channels=self.model_dim,
residual_dtype=self.activation_dtype,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
eps=self._config.rms_norm_eps,
)
self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
"""
Model implementation
"""
def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig,
base_mp_group: MPType) -> None:
"""
Base implementation for initialization. By default, this will initialize
the traditional components of a transformer model:
- Embedding
- QKV projection
- Self attention
- Attention output projection
- Feed forward network
- Normalization
- Unembedding
Arguments:
config (DSModelImplementationConfig): Model-specific configuration. No assumptions
should be made about this config that are not closely tied to the specific
model implementation.
engine_config (RaggedInferenceEngineConfig): Engine configuration.
base_mp_group (MPType): Base communication group for Tensor-parallel inference.
"""
super().__init__(config, engine_config, base_mp_group)
self.make_norm_layer()
self.make_qkv_layer()
self.make_attn_layer()
self.make_attn_out_layer()
self.make_moe_layer()
self.make_shared_expert_mlp_1_layer()
self.make_shared_expert_mlp_2_layer()
self.make_shared_expert_gate_layer()
self.make_embedding_layer()
self.make_unembedding_layer()
self._kv_cache_config = None
"""
Forward implementations
"""
def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.
Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)
if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")
return embed
def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
optimization to fuse the layer norm of the next layer into the current layer.
Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
hidden states after pre normalization.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
# TODO(cmikeh2): Distribute ragged_batch_info to all modules
cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)
hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)
residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)
shared_expert_output = self.shared_expert_mlp_1(hidden_states, cur_params.shared_moe_mlp_1, b=None)
shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None)
shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1]
# shared_expert_gate_output shape[-1] is 1
shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output))
hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1,
cur_params.moe_mlp_2)
hidden_states.add_(shared_expert_output)
if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)
if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
else:
# On last layer, we just need to perform the residual add. Adding into the residual
# here is safe.
residual.add_(hidden_states)
return residual, hidden_states
def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)
if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))
dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)
full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))
return full_logits
else:
return logits
def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
residual = self._forward_embed(wrapped_batch)
residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)
for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch)
return self._forward_unembed(residual, wrapped_batch)
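For intuition, the MoE block in `_forward_transformer` above mirrors the HF `Qwen2MoeSparseMoeBlock`: the routed-expert output is added to a shared expert whose contribution is scaled by a sigmoid gate. A self-contained sketch of just that combination in plain `torch` (random weights, not the DeepSpeed kernels):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens, model_dim, shared_dim = 4, 2048, 5632
x = torch.randn(tokens, model_dim)

# Shared expert: a SiLU-gated MLP, as in Qwen2MoeMLP.
w_gate = torch.randn(shared_dim, model_dim) * 0.02
w_up = torch.randn(shared_dim, model_dim) * 0.02
w_down = torch.randn(model_dim, shared_dim) * 0.02
shared_out = (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

# Shared-expert gate: one logit per token, squashed with a sigmoid.
w_shared_gate = torch.randn(1, model_dim) * 0.02
shared_out = shared_out * torch.sigmoid(x @ w_shared_gate.T)

# routed_out stands in for the top-4-of-60 expert mixture produced by self.moe(...).
routed_out = torch.zeros_like(x)

# Matches hidden_states.add_(shared_expert_output) in _forward_transformer.
hidden_states = routed_out + shared_out
print(hidden_states.shape)
```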

@@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Any
from ...config_v2 import RaggedInferenceEngineConfig
from ..inference_policy_base import ContainerMap, InferenceV2Policy
from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
from .model import Qwen2MoeInferenceModel
class Qwen2MoePolicy(InferenceV2Policy):

    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2MoeInferenceModel:
        return Qwen2MoeInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)

    def build_container_map(self) -> ContainerMap:
        map = ContainerMap()

        transformer_containers = [Qwen2MoeTransformerContainer(self.model) for _ in range(self.model.num_layers)]
        map.set_transformer_params(['model.layers'], transformer_containers)
        map.set_non_transformer_params(Qwen2MoeNonTransformerContainer(self.model))
        map.set_unmapped_params([])

        return map

@@ -42,7 +42,7 @@ class DSMultiGemmMoE(DSMoEBase):
        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
            return False

-       if config.top_k != 1 and config.top_k != 2:
+       if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
            return False

        return True
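The `top_k == 4` case added to the gating macro and to the `DSMultiGemmMoE` config check is what Qwen1.5-MoE's router needs: each token is dispatched to the top 4 of 60 experts, and when the model config sets `norm_topk_prob` the selected routing weights are renormalized (surfaced above as `normalize_expert_scores`). A plain-`torch` sketch of that routing math, for reference only; the real computation runs inside the DeepSpeed gating kernels:

```python
import torch

n_experts, top_k = 60, 4
logits = torch.randn(8, n_experts)  # router logits for 8 tokens (output of mlp.gate)

scores = torch.softmax(logits, dim=-1)
topk_scores, topk_experts = torch.topk(scores, top_k, dim=-1)

normalize_expert_scores = True  # mirrors config.norm_topk_prob -> normalize_scores
if normalize_expert_scores:
    topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)

# topk_experts[i] lists the 4 experts token i is dispatched to;
# topk_scores[i] are the weights used when recombining their outputs.
print(topk_experts[0], topk_scores[0])
```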