Mirror of https://github.com/microsoft/DeepSpeed.git
Fix expert grad scaling problem with ZeRO optimizer (#6546)
Fixes #6545.

- Expert gradient averaging: divide by dp_world_size instead of edp_world_size.
- Unit test: make sure that models with different dp/ep sizes produce the same expert gradients.

---------

Co-authored-by: wangyiou <wangyiou@xiaohongshu.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

This commit is contained in:
Parent: bf03f48352
Commit: b647fb2470

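The scaling issue can be illustrated with a small numeric sketch (not part of the commit; the group sizes and gradient values are invented for illustration):

# Hypothetical setup: 2 data-parallel ranks and expert parallelism 2, so each
# expert's data-parallel (edp) group contains a single rank.
dp_world_size = 2
ep_size = 2
edp_world_size = dp_world_size // ep_size  # 1

# Total gradient accumulated for one expert from all tokens routed to it.
# With EP=1 it is the sum over the dp group; with EP=2 it already sits on one
# rank after the all-to-all. The values are arbitrary.
summed_grad = 3.0 + 5.0

grad_ep1_baseline = summed_grad / dp_world_size  # 4.0 -- EP=1 reference
grad_old_scaling = summed_grad / edp_world_size  # 8.0 -- ep_size times too large
grad_new_scaling = summed_grad / dp_world_size   # 4.0 -- matches the EP=1 baseline

assert grad_new_scaling == grad_ep1_baseline

With the fix, expert gradients are still reduced inside their expert-data-parallel group but averaged by the full data-parallel world size, so EP=1 and EP>1 runs see the same expert gradients; that is exactly what the new unit test below checks.
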
@@ -1070,14 +1070,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             for i, param, param_id in self.params_in_ipg_bucket:
 
                 process_group = self.dp_process_group
-                grad_reduc = self.get_gradient_for_reduction(param)
-                #Averages gradients at parameter level if ipg has a moe param
-                #Otherwise averaging is done at the entire buffer level at the end of the loop
-                # MoE param have different groups
+
                 if self.ipg_bucket_has_moe_params:
                     process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                         param) else self.dp_process_group
-                    grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))
 
                 partition_ids = self.param_to_partition_ids[i][param_id]
                 assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
@@ -1116,8 +1112,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     curr_size += numel
                     prev_id, prev_process_group = partition_id, process_group
 
-            if not self.ipg_bucket_has_moe_params:
-                tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
+            tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
 
             buckets = {}
             for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
@@ -7,6 +7,7 @@ import torch
 import deepspeed
 import pytest
 import gc
+import random
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
 import deepspeed.comm as dist
@@ -238,3 +239,114 @@ class TestTopkGate(DistributedTest):
                                             [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
         position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
         check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
+
+
+class TestExpertWeightGradWithZero(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("zero_stage", [0, 1, 2])
+    def test(self, zero_stage):
+
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        def seed_everything(seed=11):
+            random.seed(seed)
+            torch.manual_seed(seed)
+            get_accelerator().manual_seed(seed)
+            get_accelerator().manual_seed_all(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+        def get_state_dict_ep2(state_dict):
+            """
+            convert state_dict from EP=1 to EP=2
+            """
+            rank = int(deepspeed.comm.get_rank())
+            ep_state_dict = dict()
+            dst_sub_key = f"deepspeed_moe.experts.deepspeed_experts.0"
+            src_sub_key = f"deepspeed_moe.experts.deepspeed_experts.{rank}"
+            for moe_layer in ["moe_1", "moe_2"]:
+                for mlp_in_moe in [0, 1]:
+                    dst_key = f"{moe_layer}.{dst_sub_key}.{mlp_in_moe}"
+                    src_key = f"{moe_layer}.{src_sub_key}.{mlp_in_moe}"
+                    ep_state_dict[f"{dst_key}.weight"] = state_dict[f"{src_key}.weight"].detach().clone()
+                    ep_state_dict[f"{dst_key}.bias"] = state_dict[f"{src_key}.bias"].detach().clone()
+
+            for key in state_dict.keys():
+                if "deepspeed_moe.experts.deepspeed_experts" not in key:
+                    ep_state_dict[key] = state_dict[key].detach().clone()
+            return ep_state_dict
+
+        def get_models(hidden_dim):
+            model_ep1 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=1, use_rts=False)
+            model_ep2 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=2, use_rts=False)
+
+            state_dict_ep1 = model_ep1.state_dict()
+            state_dict_ep2 = get_state_dict_ep2(state_dict_ep1)
+            model_ep2.load_state_dict(state_dict_ep2)
+
+            model_ep1, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep1)
+            model_ep2, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep2)
+
+            return model_ep1, model_ep2
+
+        def extract_expert_grad(model, expert_id):
+
+            def _get_weight_bias(experts):
+                return ([deepspeed.utils.safe_get_full_grad(expert[0].weight)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[0].bias)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[1].weight)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[1].bias)
+                         for expert in experts][expert_id].detach().clone())
+
+            return (*_get_weight_bias(model.moe_1.deepspeed_moe.experts.deepspeed_experts),
+                    *_get_weight_bias(model.moe_2.deepspeed_moe.experts.deepspeed_experts))
+
+        seed_everything()
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.1,
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage
+            }
+        }
+
+        hidden_dim = 4
+        total_samples = 2
+        rank = deepspeed.comm.get_rank()
+        model_ep1, model_ep2 = get_models(hidden_dim)
+
+        data_loader = sequence_dataloader(model=model_ep1,
+                                          total_samples=total_samples,
+                                          hidden_dim=hidden_dim,
+                                          device=model_ep1.device,
+                                          dtype=torch.float32)
+        expert_weight_grad_ep1 = []
+        expert_weight_grad_ep2 = []
+        for batch in data_loader:
+            loss_ep1 = model_ep1(batch[0], batch[1])
+            loss_ep2 = model_ep2(batch[0], batch[1])
+
+            model_ep1.backward(loss_ep1)
+            model_ep2.backward(loss_ep2)
+
+            expert_weight_grad_ep1.extend(extract_expert_grad(model_ep1, rank))
+            expert_weight_grad_ep2.extend(extract_expert_grad(model_ep2, 0))
+
+            model_ep1.step()
+            model_ep2.step()
+
+        assert len(expert_weight_grad_ep1) == len(expert_weight_grad_ep2)
+        for grad_from_ep1, grad_from_ep2 in zip(expert_weight_grad_ep1, expert_weight_grad_ep2):
+            assert torch.allclose(grad_from_ep1, grad_from_ep2, atol=0, rtol=1e-4)
@@ -79,7 +79,7 @@ class Curriculum_SimpleModel(SimpleModel):
 
 class SimpleMoEModel(torch.nn.Module):
 
-    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
+    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False, use_rts=True):
         super(SimpleMoEModel, self).__init__()
         self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
         expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
@@ -89,7 +89,8 @@ class SimpleMoEModel(torch.nn.Module):
                                              ep_size=ep_size,
                                              use_residual=use_residual,
                                              num_experts=num_experts,
-                                             k=1)
+                                             k=1,
+                                             use_rts=use_rts)
         # interleaving MoE modules with dense to create an opportunity
         # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
         self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
@@ -98,7 +99,8 @@ class SimpleMoEModel(torch.nn.Module):
                                              ep_size=ep_size,
                                              use_residual=use_residual,
                                              num_experts=num_experts,
-                                             k=1)
+                                             k=1,
+                                             use_rts=use_rts)
         self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
 
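Note on use_rts: the new flag is only threaded through to the two deepspeed.moe.layer.MoE layers, and the test above constructs both models with use_rts=False; presumably this disables random token selection in the gate so that the EP=1 and EP=2 models route tokens identically and their expert gradients remain directly comparable.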