Mirror of https://github.com/microsoft/DeepSpeed.git
MoE inference + PR-MoE model support (#1705)
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Zhewei Yao <zheweiy@berkeley.edu>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Parent: 3293cf72a0
Commit: e46d808a1b

@@ -0,0 +1,8 @@
include *.txt README.md
recursive-include requirements *.txt

# this is for Windows only
recursive-include deepspeed *.tr
recursive-exclude deepspeed/ops/csrc *.cpp *.h *.cu *.cuh *.cc
prune csrc
prune op_builder
@@ -1,5 +1,8 @@
#include "custom_cuda_layers.h"

#define MAX_CAP 4
#define MAX_SEQ 2048

inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;

@@ -168,7 +171,8 @@ __global__ void fused_bias_residual(float* input,
                                    const float* residual,
                                    const float* bias,
                                    int total_count,
                                    int intermediate_size)
                                    int intermediate_size,
                                    bool add_bias)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);

@@ -178,12 +182,18 @@ __global__ void fused_bias_residual(float* input,
    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 res_vec = residual_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += (res_vec.x + bias_data.x);
        data.y += (res_vec.y + bias_data.y);
        data.z += (res_vec.z + bias_data.z);
        data.w += (res_vec.w + bias_data.w);
        if (add_bias) {
            float4 bias_data = bias_cast[offset % intermediate_size];
            data.x += (res_vec.x + bias_data.x);
            data.y += (res_vec.y + bias_data.y);
            data.z += (res_vec.z + bias_data.z);
            data.w += (res_vec.w + bias_data.w);
        } else {
            data.x += res_vec.x;
            data.y += res_vec.y;
            data.z += res_vec.z;
            data.w += res_vec.w;
        }

        input_cast[offset] = data;
    }
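As a reference for the change above, here is a minimal PyTorch sketch of what fused_bias_residual computes once the add_bias flag is threaded through; the tensor names and shapes are illustrative, not part of the CUDA source.

import torch

def bias_residual_ref(hidden, residual, bias, add_bias=True):
    # hidden, residual: [batch, seq, intermediate]; bias: [intermediate]
    # Mirrors the kernel: out = hidden + residual, optionally plus the broadcast bias.
    out = hidden + residual
    if add_bias:
        out = out + bias
    return out

x, r, b = torch.randn(2, 4, 8), torch.randn(2, 4, 8), torch.randn(8)
assert torch.allclose(bias_residual_ref(x, r, b), x + r + b)
assert torch.allclose(bias_residual_ref(x, r, b, add_bias=False), x + r)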
@@ -193,7 +203,8 @@ __global__ void fused_bias_residual(__half* input,
                                    const __half* residual,
                                    const __half* bias,
                                    int total_count,
                                    int intermediate_size)
                                    int intermediate_size,
                                    bool add_bias)
{
#if __CUDA_ARCH__ >= 700

@@ -208,11 +219,8 @@ __global__ void fused_bias_residual(__half* input,
        float2 vals_vec = input_cast[offset];
        float2 res_vec = residual_cast[offset];

        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

@@ -220,13 +228,21 @@ __global__ void fused_bias_residual(__half* input,
        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += (low_res.x + low_bias.x);
        low_data.y += (low_res.y + low_bias.y);
        high_data.x += (high_res.x + high_bias.x);
        high_data.y += (high_res.y + high_bias.y);
        if (add_bias) {
            float2 bias_vec = bias_cast[offset % intermediate_size];
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
            float2 low_bias = __half22float2(bias_half[0]);
            float2 high_bias = __half22float2(bias_half[1]);
            low_data.x += (low_res.x + low_bias.x);
            low_data.y += (low_res.y + low_bias.y);
            high_data.x += (high_res.x + high_bias.x);
            high_data.y += (high_res.y + high_bias.y);
        } else {
            low_data.x += low_res.x;
            low_data.y += low_res.y;
            high_data.x += high_res.x;
            high_data.y += high_res.y;
        }

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);
@@ -242,6 +258,7 @@ void launch_bias_residual(T* input,
                          const T* bias,
                          int batch,
                          int intermediate_size,
                          bool add_bias,
                          cudaStream_t stream)
{
    int total_count = batch * intermediate_size / 4;

@@ -249,21 +266,13 @@ void launch_bias_residual(T* input,
    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);

    fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
        input, residual, bias, total_count, intermediate_size / 4);
        input, residual, bias, total_count, intermediate_size / 4, add_bias);
}

template void launch_bias_residual<float>(float*,
                                          const float*,
                                          const float*,
                                          int,
                                          int,
                                          cudaStream_t);
template void launch_bias_residual<__half>(__half*,
                                           const __half*,
                                           const __half*,
                                           int,
                                           int,
                                           cudaStream_t);
template void
launch_bias_residual<float>(float*, const float*, const float*, int, int, bool, cudaStream_t);
template void
launch_bias_residual<__half>(__half*, const __half*, const __half*, int, int, bool, cudaStream_t);

__global__ void gptj_residual_add(float* input,
                                  float* output,
@@ -368,3 +377,95 @@ template void
launch_gptj_residual_add<float>(float*, float*, float*, float*, int, int, cudaStream_t);
template void
launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t);

__global__ void moe_res_matmul(float* residual,
                               float* coef,
                               float* mlp_out,
                               int seq_len,
                               int hidden_dim)
{
    unsigned tid = threadIdx.x;
    float4* residual_cast = reinterpret_cast<float4*>(residual);
    float4* coef_cast = reinterpret_cast<float4*>(coef);
    float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);

    residual_cast += blockIdx.x * hidden_dim;
    mlp_out_cast += blockIdx.x * hidden_dim;

    float4* coef_cast2 = coef_cast + hidden_dim;

    while (tid < hidden_dim) {
        float4 res = residual_cast[tid];
        float4 mlp = mlp_out_cast[tid];
        float4 coef1 = coef_cast[tid];
        float4 coef2 = coef_cast2[tid];
        mlp.x = mlp.x * coef2.x + res.x * coef1.x;
        mlp.y = mlp.y * coef2.y + res.y * coef1.y;
        mlp.z = mlp.z * coef2.z + res.z * coef1.z;
        mlp.w = mlp.w * coef2.w + res.w * coef1.w;
        mlp_out_cast[tid] = mlp;
        tid += blockDim.x;
    }
}

__global__ void moe_res_matmul(__half* residual,
                               __half* coef,
                               __half* mlp_out,
                               int seq_len,
                               int hidden_dim)
{
    unsigned tid = threadIdx.x;

    float2* residual_cast = reinterpret_cast<float2*>(residual);
    float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
    float2* coef_cast = reinterpret_cast<float2*>(coef);
    float2* coef_cast2 = coef_cast + hidden_dim;

    residual_cast += blockIdx.x * hidden_dim;
    mlp_out_cast += blockIdx.x * hidden_dim;

    while (tid < hidden_dim) {
        float2 res = residual_cast[tid];
        float2 coef1 = coef_cast[tid];
        float2 coef2 = coef_cast2[tid];
        float2 data = mlp_out_cast[tid];
        __half* data_h = reinterpret_cast<__half*>(&data);
        __half* coef1_h = reinterpret_cast<__half*>(&coef1);
        __half* coef2_h = reinterpret_cast<__half*>(&coef2);
        __half* res_h = reinterpret_cast<__half*>(&res);
        data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
        data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
        data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
        data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];

        mlp_out_cast[tid] = data;
        tid += blockDim.x;
    }
}

template <typename T>
void launch_moe_res_matmul(T* residual,
                           T* coef,
                           T* mlp_out,
                           int seq_len,
                           int hidden_dim,
                           cudaStream_t stream)
{
    dim3 grid_dim(seq_len);
    dim3 block_dim(1024);
    moe_res_matmul<<<grid_dim, block_dim, 0, stream>>>(
        residual, coef, mlp_out, seq_len, hidden_dim / 4);
}

template void launch_moe_res_matmul(float* residual,
                                    float* coef,
                                    float* mlp_out,
                                    int seq_len,
                                    int hidden_dim,
                                    cudaStream_t stream);
template void launch_moe_res_matmul(__half* residual,
                                    __half* coef,
                                    __half* mlp_out,
                                    int seq_len,
                                    int hidden_dim,
                                    cudaStream_t stream);
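For readers of the new moe_res_matmul kernels: as the code is written, coef packs two per-channel coefficient vectors side by side (coef1 in the first hidden_dim elements, coef2 in the next hidden_dim), shared across all tokens, and the kernel overwrites mlp_out in place with coef2 * mlp_out + coef1 * residual. A rough PyTorch equivalent, with illustrative shapes:

import torch

def moe_res_matmul_ref(residual, coef, mlp_out):
    # residual, mlp_out: [tokens, hidden]; coef: [2 * hidden] packed as [coef1 | coef2]
    hidden = residual.shape[-1]
    coef1, coef2 = coef[:hidden], coef[hidden:]
    # Same per-element update the kernel performs in place on mlp_out.
    return mlp_out * coef2 + residual * coef1

res = torch.randn(6, 16)
mlp = torch.randn(6, 16)
coef = torch.randn(32)
print(moe_res_matmul_ref(res, coef, mlp).shape)  # torch.Size([6, 16])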

@@ -1,5 +1,4 @@

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>

@@ -9,19 +8,30 @@

std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});

#define MAX_OUT_TOKES 10

template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
                      at::Tensor& attn_mask,
                      bool triangular,
                      bool recompute,
                      bool local_attention,
                      int window_size)
                      int window_size,
                      bool async_op)
{
    auto attn_scores_c = attn_scores.contiguous();
    int bsz = attn_scores_c.size(0);
    int seq_len = attn_scores_c.size(2);
    int soft_len = attn_scores_c.size(3);
    int heads = attn_scores_c.size(1);

    int seq_len = attn_scores_c.size(1);
    int len = attn_scores_c.sizes().size();
    if (len > 3) seq_len = attn_scores_c.size(2);

    int soft_len = attn_scores_c.size(2);
    if (len > 3) soft_len = attn_scores_c.size(3);

    int heads = 1;
    if (len > 3) heads = attn_scores_c.size(1);

    launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
                           (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
                           triangular,

@@ -33,11 +43,57 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
                           seq_len,
                           soft_len,
                           1.0,
                           at::cuda::getCurrentCUDAStream());
                           Context::Instance().GetCurrentStream(async_op));

    return attn_scores_c;
}
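The reworked ds_softmax now accepts both 3-D score tensors (e.g. MoE gating scores shaped [batch, seq, soft_len]) and the usual 4-D attention scores ([batch, heads, seq, soft_len]). A small PyTorch sketch of the same shape bookkeeping, for illustration only:

import torch

def softmax_shapes(attn_scores):
    # Mirrors the updated ds_softmax: 4-D tensors carry a heads dimension,
    # 3-D tensors are treated as heads == 1.
    bsz = attn_scores.size(0)
    if attn_scores.dim() > 3:
        heads, seq_len, soft_len = attn_scores.size(1), attn_scores.size(2), attn_scores.size(3)
    else:
        heads, seq_len, soft_len = 1, attn_scores.size(1), attn_scores.size(2)
    return bsz, heads, seq_len, soft_len

print(softmax_shapes(torch.zeros(2, 8, 128, 128)))  # (2, 8, 128, 128)
print(softmax_shapes(torch.zeros(2, 128, 4)))       # (2, 1, 128, 4)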
|
||||
template <typename T>
|
||||
void allocate_workspace(size_t hidden_dim,
|
||||
size_t max_seq_len,
|
||||
size_t batch_size,
|
||||
size_t head_size = 128)
|
||||
{
|
||||
size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
|
||||
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
|
||||
{
|
||||
auto options = at::TensorOptions()
|
||||
.dtype(Q.options().dtype())
|
||||
.layout(at::kStrided)
|
||||
.device(at::kCUDA)
|
||||
.requires_grad(false);
|
||||
T* workspace = (T*)Context::Instance().GetWorkSpace();
|
||||
float alpha = 1;
|
||||
float gemm_beta = 0.0;
|
||||
|
||||
if (!workspace) {
|
||||
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
|
||||
workspace = (T*)Context::Instance().GetWorkSpace();
|
||||
}
|
||||
|
||||
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
|
||||
unsigned m = W.size(1);
|
||||
unsigned n = Q.size(1) * Q.size(2);
|
||||
unsigned k = Q.size(0);
|
||||
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_T,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha,
|
||||
&gemm_beta,
|
||||
(T*)W.data_ptr(),
|
||||
(T*)Q.data_ptr(),
|
||||
(T*)O.data_ptr(),
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
return O;
|
||||
}
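Going by the function name and the GEMM dimensions above (m = W.size(1), n = Q.size(1) * Q.size(2), k = Q.size(0)), einsum_sec_sm_ecm contracts the shared s dimension of Q ([s, e, c]) and W ([s, m]) into an [e, c, m] output. A PyTorch reference sketch under that assumption:

import torch

def einsum_sec_sm_ecm_ref(Q, W):
    # Q: [s, e, c] activations, W: [s, m] projection weight.
    # The CUDA binding performs the same contraction with one cuBLAS GEMM
    # into a preallocated workspace; here we simply call torch.einsum.
    return torch.einsum('sec,sm->ecm', Q, W)

Q = torch.randn(64, 2, 8)
W = torch.randn(64, 32)
print(einsum_sec_sm_ecm_ref(Q, W).shape)  # torch.Size([2, 8, 32])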
|
||||
|
||||
template <typename T>
|
||||
void attention_unfused(at::Tensor& prev_key_cont,
|
||||
at::Tensor& query_cont,
|
||||
|
@ -61,7 +117,7 @@ void attention_unfused(at::Tensor& prev_key_cont,
|
|||
.requires_grad(false);
|
||||
float alpha = norm_factor;
|
||||
float gemm_beta = 0.0;
|
||||
auto attn_score = at::zeros({bsz, heads, seq_len, soft_len}, options);
|
||||
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
|
||||
int k = prev_value_cont.size(2) / heads;
|
||||
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
|
||||
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
|
||||
|
@ -80,8 +136,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
|
|||
seq_len * soft_len,
|
||||
bsz * heads,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
attn_score =
|
||||
ds_softmax<T>(attn_score, attn_mask, triangular, recompute, local_attention, window_size);
|
||||
attn_score = ds_softmax<T>(
|
||||
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
|
||||
alpha = 1.0;
|
||||
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
|
||||
k,
|
||||
|
@ -177,12 +233,12 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
|
|||
auto residual_cont = residual.contiguous();
|
||||
|
||||
int bsz = input_cont.size(0) * input_cont.size(1);
|
||||
|
||||
launch_bias_residual((T*)input_cont.data_ptr(),
|
||||
(T*)residual_cont.data_ptr(),
|
||||
(T*)bias.data_ptr(),
|
||||
bsz,
|
||||
input_cont.size(2),
|
||||
(bias.size(0) > 1),
|
||||
Context::Instance().GetCurrentStream());
|
||||
return input_cont;
|
||||
}
|
||||
|
@ -409,7 +465,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight)
|
||||
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
|
||||
{
|
||||
auto input_cont = input.contiguous();
|
||||
auto options = at::TensorOptions()
|
||||
|
@ -422,7 +478,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight)
|
|||
int bsz = input_cont.size(0) * input_cont.size(1);
|
||||
float alpha = (T)1.0;
|
||||
float gemm_beta = (T)0.0;
|
||||
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
|
||||
cublasSetStream(Context::Instance().GetCublasHandle(),
|
||||
Context::Instance().GetCurrentStream(async_op));
|
||||
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
|
@ -435,7 +492,6 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight)
|
|||
(T*)input_cont.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
|
@ -503,7 +559,6 @@ void mlp_unfused_cublas(at::Tensor& output,
|
|||
(T*)inp_norm.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
launch_bias_gelu((T*)output.data_ptr(),
|
||||
(T*)bias.data_ptr(),
|
||||
weight.size(1),
|
||||
|
@ -601,10 +656,11 @@ template <typename T>
|
|||
at::Tensor fused_gemm_gelu(at::Tensor& input,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& bias,
|
||||
at::Tensor& weight_out)
|
||||
at::Tensor& weight_out,
|
||||
const float epsilon,
|
||||
bool preLayerNorm,
|
||||
bool async_op)
|
||||
{
|
||||
// cudaStreamWaitEvent(
|
||||
// Context::Instance().GetCurrentStream(true), Context::Instance().GetCompEvent(1), 0);
|
||||
auto input_cont = input.contiguous();
|
||||
auto options = at::TensorOptions()
|
||||
.dtype(input_cont.options().dtype())
|
||||
|
@ -717,6 +773,59 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
|
|||
return {query_cont, key_cont};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
|
||||
at::Tensor& weight,
|
||||
at::Tensor& bias,
|
||||
const float epsilon,
|
||||
at::Tensor& q_scale,
|
||||
int groups,
|
||||
bool preLayerNorm)
|
||||
{
|
||||
auto input_cont = input.contiguous();
|
||||
auto options = at::TensorOptions()
|
||||
.dtype(input_cont.options().dtype())
|
||||
.layout(at::kStrided)
|
||||
.device(at::kCUDA)
|
||||
.requires_grad(false);
|
||||
|
||||
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
|
||||
|
||||
int bsz = input_cont.size(0) * input_cont.size(1);
|
||||
|
||||
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
|
||||
launch_bias_gelu((T*)output.data_ptr(),
|
||||
(T*)bias.data_ptr(),
|
||||
weight.size(1),
|
||||
bsz,
|
||||
Context::Instance().GetCurrentStream());
|
||||
|
||||
return output;
|
||||
}

at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
    int M = moe_res.size(0) * moe_res.size(1);
    int N = moe_res.size(2);
    Context::Instance().SynchComm();
    if (moe_res.scalar_type() == at::kFloat) {
        launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
                                     (float*)coef.data_ptr(),
                                     (float*)output.data_ptr(),
                                     M,
                                     N,
                                     at::cuda::getCurrentCUDAStream());
    } else {
        launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
                                      (__half*)coef.data_ptr(),
                                      (__half*)output.data_ptr(),
                                      M,
                                      N,
                                      at::cuda::getCurrentCUDAStream());
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");

@@ -756,4 +865,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("gptj_residual_add", &gptj_residual_add, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp32",
          &einsum_sec_sm_ecm<float>,
          "DeepSpeed vector-MM with fp32 (CUDA)");

    m.def("einsum_sec_sm_ecm_fp16",
          &einsum_sec_sm_ecm<__half>,
          "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
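Once the extension is compiled, the bindings registered above are callable from Python. The sketch below is a hedged illustration of exercising the two MoE-specific ops; the InferenceBuilder import reflects how DeepSpeed typically loads these kernels and is an assumption, not something this diff shows.

import torch
from deepspeed.ops.op_builder import InferenceBuilder  # assumed builder; not part of this diff

ops = InferenceBuilder().load()

# Gating projection: Q laid out as [s, e, c], W as [s, m] (see einsum_sec_sm_ecm above).
Q = torch.randn(64, 2, 8, dtype=torch.half, device='cuda')
W = torch.randn(64, 32, dtype=torch.half, device='cuda')
gate = ops.einsum_sec_sm_ecm_fp16(Q, W)

# PR-MoE residual mixing: writes coef2 * out + coef1 * residual into out in place.
res = torch.randn(2, 128, 1024, dtype=torch.half, device='cuda')
coef = torch.randn(2 * 1024, dtype=torch.half, device='cuda')
out = torch.randn_like(res)
ops.moe_res_matmul(res, coef, out)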
|
@ -52,6 +52,8 @@ public:
|
|||
cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
|
||||
cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
virtual ~Context()
|
||||
|
@ -60,6 +62,8 @@ public:
|
|||
cudaFree(_workspace);
|
||||
cudaEventDestroy(_comp1_event);
|
||||
cudaEventDestroy(_comp2_event);
|
||||
cudaEventDestroy(_comp_event);
|
||||
cudaEventDestroy(_comm_event);
|
||||
}
|
||||
|
||||
static Context& Instance()
|
||||
|
@ -81,10 +85,35 @@ public:
|
|||
_workSpaceSize = size;
|
||||
}
|
||||
|
||||
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
|
||||
|
||||
size_t get_workspace_size() const { return _workSpaceSize; }
|
||||
void* GetWorkSpace() { return _workspace; }
|
||||
|
||||
inline unsigned new_token(unsigned layer_id)
|
||||
{
|
||||
if (layer_id == 0) _token_length++;
|
||||
return _token_length;
|
||||
}
|
||||
|
||||
inline void reset_tokens(unsigned initial_tokens = 0)
|
||||
{
|
||||
_num_tokens = initial_tokens;
|
||||
} //_token_length = 0; }
|
||||
|
||||
inline unsigned current_tokens() const { return _num_tokens; }
|
||||
|
||||
inline void advance_tokens() { _num_tokens++; }
|
||||
|
||||
curandGenerator_t& GetRandGenerator() { return _gen; }
|
||||
|
||||
cudaStream_t GetCommStream(bool async_op = false)
|
||||
{
|
||||
if (!_comm_stream)
|
||||
_comm_stream = async_op ? at::cuda::getStreamFromPool(true)
|
||||
: at::cuda::getCurrentCUDAStream();
|
||||
return _comm_stream;
|
||||
}
|
||||
cudaStream_t GetCurrentStream(bool other_stream = false)
|
||||
{
|
||||
// get current pytorch stream.
|
||||
|
@ -96,8 +125,6 @@ public:
|
|||
return stream;
|
||||
}
|
||||
|
||||
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
|
||||
|
||||
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
|
||||
|
||||
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
|
||||
|
@@ -111,9 +138,24 @@ public:

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }
|
||||
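SynchComp and SynchComm order the compute and communication streams through CUDA events: record an event on one stream, make the other stream wait on it. The same handshake expressed with PyTorch's stream and event API, purely as an illustration:

import torch

comp_stream = torch.cuda.current_stream()
comm_stream = torch.cuda.Stream()
comp_done = torch.cuda.Event()
comm_done = torch.cuda.Event()

# SynchComp: the comm stream waits for work already queued on the compute stream.
comp_done.record(comp_stream)
comm_stream.wait_event(comp_done)

# SynchComm: the compute stream waits for collectives issued on the comm stream.
comm_done.record(comm_stream)
comp_stream.wait_event(comm_done)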
|
||||
private:
|
||||
curandGenerator_t _gen;
|
||||
cublasHandle_t _cublasHandle;
|
||||
|
||||
cudaEvent_t _comp_event;
|
||||
cudaEvent_t _comm_event;
|
||||
|
||||
void* _workspace;
|
||||
uint64_t _seed;
|
||||
uint64_t _curr_offset;
|
||||
|
@ -124,5 +166,12 @@ private:
|
|||
|
||||
cudaStream_t _stream;
|
||||
|
||||
unsigned _token_length;
|
||||
unsigned _num_tokens;
|
||||
std::vector<std::array<int, 3>> _gemm_algos;
|
||||
|
||||
cudaStream_t _comp_stream;
|
||||
cudaStream_t _comm_stream;
|
||||
|
||||
std::unordered_map<int, int> _world_sizes;
|
||||
};
|
||||
|
|
|
@ -43,6 +43,7 @@ void launch_bias_residual(T* input,
|
|||
const T* bias,
|
||||
int size,
|
||||
int intermediate_size,
|
||||
bool add_bias,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
|
@ -96,3 +97,11 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
|
|||
unsigned num_heads,
|
||||
unsigned batch,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_moe_res_matmul(T* residual,
|
||||
T* coef,
|
||||
T* mlp_out,
|
||||
int seq_len,
|
||||
int hidden_dim,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@@ -218,21 +218,30 @@ def add_config_arguments(parser):


def init_inference(model,
                   triangular_masking=True,
                   mp_size=1,
                   mpu=None,
                   ep_group=None,
                   expert_mp_group=None,
                   checkpoint=None,
                   module_key='module',
                   dtype=None,
                   injection_policy=None,
                   replace_method='auto',
                   quantization_setting=None,
                   replace_with_kernel_inject=False,
                   return_tuple=True):
                   return_tuple=True,
                   ep_size=1,
                   moe=False,
                   moe_experts=1,
                   moe_type='standard'):
    """Initialize the DeepSpeed InferenceEngine.

    Arguments:
        model: Required: nn.module class before applying any wrappers

        triangular_masking: Required: this sets the type of masking for attention scores in the transformer layer;
            note that the masking is application specific.

        mp_size: Optional: Desired model parallel size, default is 1 meaning no
            model parallelism.

@@ -272,14 +281,21 @@ def init_inference(model,
        raise NotImplementedError("pipeline module support is not implemented yet")
    else:
        engine = InferenceEngine(model,
                                 triangular_masking,
                                 mp_size,
                                 ep_size,
                                 mpu,
                                 ep_group,
                                 expert_mp_group,
                                 checkpoint,
                                 dtype,
                                 injection_policy,
                                 return_tuple,
                                 replace_method,
                                 quantization_setting,
                                 replace_with_kernel_inject)
                                 replace_with_kernel_inject,
                                 moe,
                                 moe_experts,
                                 moe_type)

    return engine
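A usage sketch for the extended entry point; only the init_inference keywords come from this diff, while the model construction and inputs are placeholders for any module that contains DeepSpeed MoE layers.

import torch
import deepspeed

model = build_moe_model()  # placeholder: any torch.nn.Module containing DeepSpeed MoE layers

engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_method='auto',
                                  replace_with_kernel_inject=True,
                                  moe=True,
                                  ep_size=4,             # expert-parallel degree
                                  moe_experts=[8],       # experts per MoE layer (int or list)
                                  moe_type='residual')   # 'standard' or 'residual' (PR-MoE)
outputs = engine(input_ids)    # placeholder inputs, same signature as the wrapped model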
|
||||
|
|
|
@ -11,22 +11,34 @@ from ..module_inject.replace_module import replace_transformer_layer
|
|||
from ..utils import logger, init_distributed
|
||||
|
||||
from ..pipe import PipelineModule
|
||||
from ..moe.utils import has_moe_layers
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class InferenceEngine(Module):
|
||||
inference_mp_group = None
|
||||
inference_ep_group = None
|
||||
expert_mp_group = None
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
triangular_masking=True,
|
||||
mp_size=1,
|
||||
ep_size=1,
|
||||
mpu=None,
|
||||
ep_group=None,
|
||||
expert_mp_group=None,
|
||||
checkpoint=None,
|
||||
dtype=None,
|
||||
injection_dict=None,
|
||||
return_tuple=True,
|
||||
replace_method='auto',
|
||||
quantization_setting=None,
|
||||
replace_with_kernel_inject=False):
|
||||
replace_with_kernel_inject=False,
|
||||
moe=False,
|
||||
moe_experts=1,
|
||||
moe_type='standard'):
|
||||
"""
|
||||
Args:
|
||||
model: torch.nn.Module
|
||||
|
@ -59,6 +71,10 @@ class InferenceEngine(Module):
|
|||
self.replace_method = replace_method
|
||||
self.quantize_merge_count = 1
|
||||
self.quantization_scales = None
|
||||
self.triangular_masking = triangular_masking
|
||||
self.ep_size = ep_size
|
||||
self.ep_group = ep_group
|
||||
self.expert_mp_group = expert_mp_group
|
||||
|
||||
self._init_quantization_setting(quantization_setting)
|
||||
|
||||
|
@ -72,20 +88,29 @@ class InferenceEngine(Module):
|
|||
if self.mpu:
|
||||
self.mp_world_size = dist.get_world_size(
|
||||
group=self.mpu.get_model_parallel_group())
|
||||
self.mp_group = self.mpu.get_model_parallel_group()
|
||||
self.mp_group = mpu.get_model_parallel_group()
|
||||
elif self.mp_world_size > 1:
|
||||
self._create_model_parallel_group()
|
||||
# apply injection policy
|
||||
if self.injection_dict is not None:
|
||||
|
||||
moe, _ = has_moe_layers(self.module)
|
||||
if moe:
|
||||
self._create_ep_parallel_group(moe_experts)
|
||||
if self.injection_dict:
|
||||
for client_module, injection_policy in self.injection_dict.items():
|
||||
self._apply_injection_policy(client_module,
|
||||
injection_policy,
|
||||
return_tuple,
|
||||
replace_with_kernel_inject)
|
||||
replace_with_kernel_inject,
|
||||
moe,
|
||||
moe_experts,
|
||||
moe_type)
|
||||
elif replace_method == 'auto':
|
||||
self._apply_injection_policy(
|
||||
return_tuple=return_tuple,
|
||||
replace_with_kernel_inject=replace_with_kernel_inject)
|
||||
replace_with_kernel_inject=replace_with_kernel_inject,
|
||||
moe=moe,
|
||||
moe_experts=moe_experts,
|
||||
moe_type=moe_type)
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
logger.info(f"Place model to device: {device}")
|
||||
|
@@ -112,9 +137,40 @@ class InferenceEngine(Module):
            ranks = [i for i in range(self.mp_world_size)]
            self.mp_group = dist.new_group(ranks)
            InferenceEngine.inference_mp_group = self.mp_group

        else:
            self.mp_group = InferenceEngine.inference_mp_group

    def _create_ep_parallel_group(self, moe_experts):
        # Call the init process
        self.ep_group = {}
        self.expert_mp_group = {}
        moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
        for e in moe_experts:
            self.ep_group.update({e: None})
            self.expert_mp_group.update({e: None})
        for moe_ep_size in self.ep_group.keys():
            num_ep_groups = dist.get_world_size() // moe_ep_size
            for i in range(num_ep_groups):
                ep_cnt = i * moe_ep_size
                size = dist.get_world_size(
                ) if moe_ep_size > dist.get_world_size() else moe_ep_size
                ranks = list(range(ep_cnt, ep_cnt + size))
                _ep_group = dist.new_group(ranks)
                if dist.get_rank() in ranks:
                    self.ep_group.update({moe_ep_size: _ep_group})

            if dist.get_world_size() > moe_ep_size:
                num_expert_mp_groups = dist.get_world_size() // num_ep_groups
                expert_mp_size = dist.get_world_size() // moe_ep_size
                for i in range(num_expert_mp_groups):
                    expert_mp_comm_ranks = [
                        i + nr * moe_ep_size for nr in range(expert_mp_size)
                    ]
                    _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                    if dist.get_rank() in expert_mp_comm_ranks:
                        self.expert_mp_group.update({moe_ep_size: _expert_mp_group})

    def _init_quantization_setting(self, quantization_setting):
        self.quantize_bits = 8
        self.mlp_extra_grouping = False
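The grouping arithmetic in _create_ep_parallel_group can be sanity-checked offline; the sketch below reproduces it for a given world size and expert-parallel size without touching torch.distributed.

def ep_groups(world_size, moe_ep_size):
    # Expert-parallel groups: consecutive blocks of moe_ep_size ranks.
    num_ep_groups = world_size // moe_ep_size
    size = min(moe_ep_size, world_size)
    ep = [list(range(i * moe_ep_size, i * moe_ep_size + size)) for i in range(num_ep_groups)]
    # Expert model-parallel groups: ranks holding the same expert shard
    # (same offset inside each expert-parallel block).
    emp = []
    if world_size > moe_ep_size:
        num_expert_mp_groups = world_size // num_ep_groups
        expert_mp_size = world_size // moe_ep_size
        emp = [[i + nr * moe_ep_size for nr in range(expert_mp_size)]
               for i in range(num_expert_mp_groups)]
    return ep, emp

print(ep_groups(8, 4))  # ([[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 4], [1, 5], [2, 6], [3, 7]])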
|
@ -156,13 +212,19 @@ class InferenceEngine(Module):
|
|||
client_module=None,
|
||||
injection_policy=None,
|
||||
return_tuple=True,
|
||||
replace_with_kernel_inject=False):
|
||||
replace_with_kernel_inject=False,
|
||||
moe=False,
|
||||
moe_experts=1,
|
||||
moe_type='standard'):
|
||||
|
||||
replace_transformer_layer(client_module,
|
||||
self.module,
|
||||
triangular_masking=self.triangular_masking,
|
||||
policy=injection_policy,
|
||||
mp_size=self.mp_world_size,
|
||||
mp_group=self.mp_group,
|
||||
ep_group=self.ep_group,
|
||||
expert_mp_group=self.expert_mp_group,
|
||||
config=self.config,
|
||||
fp16=(self.dtype == torch.half),
|
||||
training=False,
|
||||
|
@ -172,7 +234,10 @@ class InferenceEngine(Module):
|
|||
self.quantize_merge_count,
|
||||
self.mlp_extra_grouping,
|
||||
self.quantize_groups),
|
||||
replace_with_kernel_inject=replace_with_kernel_inject)
|
||||
replace_with_kernel_inject=replace_with_kernel_inject,
|
||||
moe=moe,
|
||||
moe_experts=moe_experts,
|
||||
moe_type=moe_type)
|
||||
|
||||
def _load_checkpoint(self, load_dir, load_module_strict=True):
|
||||
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
|
||||
|
@ -183,7 +248,9 @@ class InferenceEngine(Module):
|
|||
'pipeline parallelism is currently not supported in inference.')
|
||||
|
||||
mp_rank = 0 if self.mp_group is None else dist.get_rank(group=self.mp_group)
|
||||
|
||||
print(
|
||||
f'self.mp_world_size: {self.mp_world_size}, mp_rank: {mp_rank}, is_pipe_parallel: {is_pipe_parallel}, (self.dtype is torch.int8): {(self.dtype is torch.int8)}, quantize_groups:{self.quantize_groups}, mlp_extra_grouping: {self.mlp_extra_grouping}'
|
||||
)
|
||||
load_path, checkpoint, quantize_config = sd_loader.load(self.mp_world_size,
|
||||
mp_rank,
|
||||
is_pipe_parallel=is_pipe_parallel,
|
||||
|
@@ -197,8 +264,17 @@ class InferenceEngine(Module):
            # Pipeline parallelism uses this to load its own checkpoint files.
            self._curr_ckpt_path = load_dir

            self.module.load_state_dict(state_dict=checkpoint['model'],
                                        strict=load_module_strict)
            self.module.load_state_dict(
                state_dict=checkpoint[self._choose_module_key(checkpoint)],
                strict=load_module_strict)

    def _choose_module_key(self, sd):
        assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
        if 'module' in sd:
            return 'module'
        elif 'model' in sd:
            return 'model'

    def _convert_to_dtype(self):
        if self.dtype is torch.int8 and self.quantization_scales is None:
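A tiny standalone illustration of the new checkpoint-key selection, mirroring _choose_module_key on hypothetical state dicts:

def choose_module_key(sd):
    # Exactly one of the two keys must be present, matching the asserts above.
    assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys"
    assert 'module' in sd or 'model' in sd, "checkpoint has neither 'model' nor 'module' keys"
    return 'module' if 'module' in sd else 'model'

print(choose_module_key({'model': {}}))   # model
print(choose_module_key({'module': {}}))  # module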
|
|
|
@ -40,7 +40,7 @@ class LinearLayer(nn.Module):
|
|||
|
||||
class ReplaceWithTensorSlicing:
|
||||
def __init__(self, mp_group=None):
|
||||
if (torch.distributed.is_initialized() and mp_group is not None):
|
||||
if mp_group is not None:
|
||||
self.gpu_index = torch.distributed.get_rank(group=mp_group)
|
||||
else:
|
||||
self.gpu_index = 0
|
||||
|
@@ -57,7 +57,7 @@ class ReplaceWithTensorSlicing:
        src_shape = src.shape
        dst_shape = dst.shape

        src_split = torch.split(src, src.shape[-1] // 3, dim=-1)
        src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1)

        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[1] == dst_shape[1]:

@@ -71,7 +71,8 @@ class ReplaceWithTensorSlicing:
                    torch.cat([qkv_s[i] for qkv_s in qkv_split],
                              axis=1) for i in range(len(qkv_split[0]))
                ]
                dst = weight_split[self.gpu_index].to(torch.cuda.current_device())
                dst.data.copy_(weight_split[self.gpu_index].to(
                    torch.cuda.current_device()).contiguous())
            else:
                if src_shape[0] == dst_shape[0]:
                    return src

@@ -82,9 +83,10 @@ class ReplaceWithTensorSlicing:
                    torch.cat([qkv_s[i] for qkv_s in qkv_split],
                              axis=0) for i in range(len(qkv_split[0]))
                ]
                dst = bias_split[self.gpu_index].to(torch.cuda.current_device())
                dst.data.copy_(bias_split[self.gpu_index].to(
                    torch.cuda.current_device()).contiguous())

        return dst.contiguous()
        return dst

    def copy(self, dst, src):
        if src is None:

@@ -103,17 +105,19 @@ class ReplaceWithTensorSlicing:
                weight_split = torch.split(src, dst_shape[0])
            else:
                self.merge_assert(src_shape[1], dst_shape[1])
                weight_split = torch.split(src, dst_shape[1], dim=1)
                weight_split = torch.split(src.data, dst_shape[1], dim=1)

            dst = weight_split[self.gpu_index].to(torch.cuda.current_device())
            dst.data.copy_(weight_split[self.gpu_index].to(
                torch.cuda.current_device()).contiguous())
        else:
            if src_shape[0] == dst_shape[0]:
                return src

            bias_split = torch.split(src, dst_shape[-1])
            dst = bias_split[self.gpu_index].to(torch.cuda.current_device())
            bias_split = torch.split(src.data, dst_shape[-1])
            dst.data.copy_(bias_split[self.gpu_index].to(
                torch.cuda.current_device()).contiguous())

        return dst.contiguous()
        return dst
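The qkv_copy path above shards a fused QKV weight across model-parallel ranks: split the fused dimension into Q, K and V, slice each per rank, and re-fuse the slices. A standalone sketch with plain torch ops and illustrative shapes:

import torch

def shard_qkv(qkv_weight, mp_size, rank):
    # qkv_weight: [hidden, 3 * hidden], Q/K/V fused along the last dimension,
    # which is the layout qkv_copy expects. Each of Q, K, V is sliced across
    # model-parallel ranks and the per-rank slices are re-fused.
    q, k, v = torch.split(qkv_weight, qkv_weight.shape[-1] // 3, dim=-1)
    per_rank = [torch.cat([t.chunk(mp_size, dim=-1)[r] for t in (q, k, v)], dim=-1)
                for r in range(mp_size)]
    return per_rank[rank]

w = torch.randn(8, 3 * 8)
print(shard_qkv(w, mp_size=2, rank=1).shape)  # torch.Size([8, 12])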
|
||||
def replace_transformer_layer(orig_layer_impl,
|
||||
|
@ -126,6 +130,8 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
num_attention_heads=-1,
|
||||
mp_size=1,
|
||||
mp_group=None,
|
||||
ep_group=None,
|
||||
expert_mp_group=None,
|
||||
preln=True,
|
||||
fp16=True,
|
||||
local_rank=-1,
|
||||
|
@ -133,9 +139,13 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
training=True,
|
||||
quantize=False,
|
||||
quantize_settings=None,
|
||||
triangular_masking=False,
|
||||
return_tuple=True,
|
||||
replace_with_kernel_inject=False,
|
||||
linear_layer_setting=None):
|
||||
linear_layer_setting=None,
|
||||
moe=False,
|
||||
moe_experts=1,
|
||||
moe_type='standard'):
|
||||
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
|
||||
Arguments:
|
||||
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
|
||||
|
@ -170,7 +180,12 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
Returns:
|
||||
Updated nn.module with replaced transformer layers
|
||||
"""
|
||||
def replace_with_policy(child, policy_cls, inference=False, preln=True, layer_id=0):
|
||||
def replace_with_policy(child,
|
||||
policy_cls,
|
||||
triangular_masking,
|
||||
inference=False,
|
||||
preln=True,
|
||||
layer_id=0):
|
||||
preln = False if policy_cls is HFBertLayerPolicy else preln
|
||||
if policy_cls is HFBertLayerPolicy:
|
||||
policy = policy_cls(child, inference=inference, preln=preln)
|
||||
|
@ -182,87 +197,143 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
assert num_attention_heads % mp_size == 0,\
|
||||
"To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" +\
|
||||
"This is because the attention computation is partitioned evenly among the parallel GPUs."
|
||||
from deepspeed.moe.utils import has_moe_layers
|
||||
moe, num_experts = has_moe_layers(child)
|
||||
|
||||
attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
|
||||
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
|
||||
attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
|
||||
if not moe or moe_type == 'standard':
|
||||
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
|
||||
else:
|
||||
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
|
||||
_res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)
|
||||
|
||||
attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
|
||||
if quantize:
|
||||
if policy_cls is not HFBertLayerPolicy:
|
||||
qkvw = qkvw.to(torch.int8)
|
||||
dense_w = dense_w.to(torch.int8)
|
||||
_h4h_w = _h4h_w.to(torch.int8)
|
||||
_4hh_w = _4hh_w.to(torch.int8)
|
||||
_h4h_w = [moe_w1.to(torch.int8)
|
||||
for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
|
||||
_4hh_w = [moe_w1.to(torch.int8)
|
||||
for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
|
||||
elif fp16:
|
||||
qkvw = qkvw.half()
|
||||
dense_w = dense_w.half()
|
||||
_h4h_w = _h4h_w.half()
|
||||
_4hh_w = _4hh_w.half()
|
||||
|
||||
_h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
|
||||
_4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
|
||||
if quantize or fp16:
|
||||
qkvb = qkvb if qkvb is None else qkvb.half()
|
||||
dense_b = dense_b if dense_b is None else dense_b.half()
|
||||
_h4h_b = _h4h_b.half()
|
||||
_4hh_b = _4hh_b.half()
|
||||
attn_nw = attn_nw if dense_b is None else attn_nw.half()
|
||||
attn_nb = attn_nb if dense_b is None else attn_nb.half()
|
||||
_h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
|
||||
_4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
|
||||
attn_nw = attn_nw if attn_nw is None else attn_nw.half()
|
||||
attn_nb = attn_nb if attn_nb is None else attn_nb.half()
|
||||
input_nw = input_nw.half()
|
||||
input_nb = input_nb.half()
|
||||
|
||||
if moe and moe_type == 'residual' and fp16:
|
||||
_res_h4h_b = _res_h4h_b.half()
|
||||
_res_4hh_b = _res_4hh_b.half()
|
||||
_res_h4h_w = _res_h4h_w.half()
|
||||
_res_4hh_w = _res_4hh_w.half()
|
||||
_res_coef = _res_coef.half()
|
||||
|
||||
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
|
||||
#expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)
|
||||
|
||||
if inference:
|
||||
transformer_config = transformer_inference.DeepSpeedInferenceConfig(
|
||||
hidden_size=hidden_size,
|
||||
heads=num_attention_heads,
|
||||
layer_norm_eps=config.layer_norm_eps if hasattr(
|
||||
config,
|
||||
'layer_norm_eps') else
|
||||
(config.layer_norm_epsilon if hasattr(config,
|
||||
'layer_norm_epsilon') else 1e-12),
|
||||
fp16=fp16,
|
||||
pre_layer_norm=preln,
|
||||
mp_size=mp_size,
|
||||
q_int8=quantize,
|
||||
return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
|
||||
triangular_masking=(policy_cls is not HFBertLayerPolicy),
|
||||
local_attention=((config.attention_layers[layer_id] == "local")
|
||||
if hasattr(config,
|
||||
'attention_layers') else False),
|
||||
window_size=(config.window_size if hasattr(config,
|
||||
'window_size') else 1),
|
||||
rotary_dim=(config.rotary_dim if hasattr(config,
|
||||
'rotary_dim') else -1),
|
||||
mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))
|
||||
if moe:
|
||||
ep_world_size = torch.distributed.get_world_size()
|
||||
local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size
|
||||
|
||||
transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
|
||||
hidden_size=hidden_size,
|
||||
heads=num_attention_heads,
|
||||
layer_norm_eps=config.layer_norm_eps if hasattr(
|
||||
config,
|
||||
'layer_norm_eps') else 1e-12,
|
||||
fp16=fp16,
|
||||
pre_layer_norm=preln,
|
||||
mp_size=mp_size,
|
||||
q_int8=quantize,
|
||||
moe_experts=local_ep_size,
|
||||
global_experts=num_experts,
|
||||
mlp_type=moe_type)
|
||||
else:
|
||||
transformer_config = transformer_inference.DeepSpeedInferenceConfig(
|
||||
hidden_size=hidden_size,
|
||||
heads=num_attention_heads,
|
||||
layer_norm_eps=config.layer_norm_eps if hasattr(
|
||||
config,
|
||||
'layer_norm_eps') else (config.layer_norm_epsilon if hasattr(
|
||||
config,
|
||||
'layer_norm_epsilon') else 1e-12),
|
||||
fp16=fp16,
|
||||
pre_layer_norm=preln,
|
||||
mp_size=mp_size,
|
||||
q_int8=quantize,
|
||||
return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
|
||||
triangular_masking=(policy_cls is not HFBertLayerPolicy),
|
||||
local_attention=((config.attention_layers[layer_id] == "local")
|
||||
if hasattr(config,
|
||||
'attention_layers') else False),
|
||||
window_size=(config.window_size if hasattr(config,
|
||||
'window_size') else 1),
|
||||
rotary_dim=(config.rotary_dim if hasattr(config,
|
||||
'rotary_dim') else -1),
|
||||
mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))
|
||||
|
||||
if quantize and quantize_settings is not None:
|
||||
(quantization_scales,
|
||||
merge_count,
|
||||
mlp_extra_grouping,
|
||||
quantize_groups) = quantize_settings
|
||||
new_module = transformer_inference.DeepSpeedTransformerInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
quantize_scales=quantization_scales[layer_id],
|
||||
quantize_groups=quantize_groups,
|
||||
merge_count=merge_count,
|
||||
mlp_extra_grouping=mlp_extra_grouping,
|
||||
qkv_merging=(policy_cls is HFBertLayerPolicy))
|
||||
if moe:
|
||||
new_module = transformer_inference.DeepSpeedMoEInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
ep_group=ep_group[num_experts],
|
||||
expert_mp_group=expert_mp_group[num_experts],
|
||||
quantize_scales=quantization_scales[layer_id],
|
||||
quantize_groups=quantize_groups,
|
||||
merge_count=merge_count,
|
||||
mlp_extra_grouping=mlp_extra_grouping,
|
||||
qkv_merging=(policy_cls is HFBertLayerPolicy))
|
||||
|
||||
else:
|
||||
new_module = transformer_inference.DeepSpeedTransformerInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
quantize_scales=quantization_scales[layer_id],
|
||||
quantize_groups=quantize_groups,
|
||||
merge_count=merge_count,
|
||||
mlp_extra_grouping=mlp_extra_grouping,
|
||||
qkv_merging=(policy_cls is HFBertLayerPolicy))
|
||||
|
||||
if quantize and qkvw.dtype != torch.int8:
|
||||
quantize_bits = 8
|
||||
quantizer = WeightQuantization()
|
||||
if policy_cls is HFBertLayerPolicy:
|
||||
data_quantized, _ = quantizer.quantize_data(qkvw, quantize_bits, quantize_groups * 3)
|
||||
data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
|
||||
else:
|
||||
data_quantized, _ = quantizer.quantize_data(qkvw, quantize_bits, quantize_groups)
|
||||
qkvw.copy_(data_quantized)
|
||||
qkvw = qkvw.to(torch.int8)
|
||||
data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
|
||||
qkvw.data.copy_(data_quantized)
|
||||
qkvw.data = qkvw.data.to(torch.int8)
|
||||
else:
|
||||
new_module = transformer_inference.DeepSpeedTransformerInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
|
||||
if moe:
|
||||
new_module = transformer_inference.DeepSpeedMoEInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
ep_group=ep_group[num_experts],
|
||||
expert_mp_group=expert_mp_group[num_experts],
|
||||
)
|
||||
|
||||
else:
|
||||
new_module = transformer_inference.DeepSpeedTransformerInference(
|
||||
transformer_config,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
new_module.config.scale_attention = scale_attention
|
||||
|
||||
# we want the weights in [input, output] shape
|
||||
|
@ -274,31 +345,69 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
return data
|
||||
|
||||
if attn_linear_layer:
|
||||
qkvw = transpose(qkvw.data)
|
||||
dense_w = transpose(dense_w)
|
||||
qkvw.data = transpose(qkvw.data)
|
||||
dense_w.data = transpose(dense_w.data)
|
||||
|
||||
if mlp_linear_layer:
|
||||
_h4h_w = transpose(_h4h_w)
|
||||
_4hh_w = transpose(_4hh_w)
|
||||
_h4h_w = [transpose(moe_w1.data)
|
||||
for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
|
||||
_4hh_w = [transpose(moe_w1.data)
|
||||
for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)
|
||||
|
||||
if moe and moe_type == 'residual':
|
||||
_res_h4h_w.data = transpose(_res_h4h_w.data)
|
||||
_res_4hh_w.data = transpose(_res_4hh_w.data)
|
||||
_res_coef.data = transpose(_res_coef.data)
|
||||
|
||||
attn_block = new_module.attention
|
||||
attn_block.attn_qkvw.data = mp_replace.qkv_copy(attn_block.attn_qkvw.data,
|
||||
qkvw)
|
||||
attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb.data, qkvb)
|
||||
attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
|
||||
attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
|
||||
|
||||
attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w)
|
||||
attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob.data, dense_b)
|
||||
attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
|
||||
attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
|
||||
|
||||
mpl_block = new_module.mlp
|
||||
mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data, _h4h_w)
|
||||
mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b.data, _h4h_b)
|
||||
mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data, _4hh_w)
|
||||
mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data, _4hh_b)
|
||||
|
||||
new_module.mlp.attn_nw = attn_nw if dense_b is None else attn_nw.to(
|
||||
torch.cuda.current_device())
|
||||
new_module.mlp.attn_nb = attn_nb if dense_b is None else attn_nb.to(
|
||||
torch.cuda.current_device())
|
||||
if moe:
|
||||
gpu_index = torch.distributed.get_rank()
|
||||
gpu_index = 0
|
||||
for ep_index in range(local_ep_size):
|
||||
mpl_block[ep_index].inter_w.data = _h4h_w[
|
||||
gpu_index * local_ep_size + ep_index].to(
|
||||
torch.cuda.current_device())
|
||||
mpl_block[ep_index].inter_b.data = _h4h_b[
|
||||
gpu_index * local_ep_size + ep_index].to(
|
||||
torch.cuda.current_device())
|
||||
mpl_block[ep_index].output_w.data = _4hh_w[
|
||||
gpu_index * local_ep_size + ep_index].to(
|
||||
torch.cuda.current_device())
|
||||
mpl_block[ep_index].output_b.data = _4hh_b[
|
||||
gpu_index * local_ep_size + ep_index].to(
|
||||
torch.cuda.current_device())
|
||||
new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
|
||||
new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
|
||||
if moe_type == 'residual':
|
||||
new_module.res_mlp.inter_w.data = _res_h4h_w.to(
|
||||
torch.cuda.current_device())
|
||||
new_module.res_mlp.inter_b.data = _res_h4h_b.to(
|
||||
torch.cuda.current_device())
|
||||
new_module.res_mlp.output_w.data = _res_4hh_w.to(
|
||||
torch.cuda.current_device())
|
||||
new_module.res_mlp.output_b.data = _res_4hh_b.to(
|
||||
torch.cuda.current_device())
|
||||
new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
|
||||
else:
|
||||
mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w, _h4h_w)
|
||||
mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b, _h4h_b)
|
||||
mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w, _4hh_w)
|
||||
mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b, _4hh_b)
|
||||
if attn_nw is None:
|
||||
new_module.mlp.attn_nw = attn_nw
|
||||
else:
|
||||
new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
|
||||
if attn_nb is None:
|
||||
new_module.mlp.attn_nb = attn_nb
|
||||
else:
|
||||
new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())
|
||||
new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
|
||||
new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
|
||||
else:
|
||||
|
@ -445,13 +554,17 @@ def replace_transformer_layer(orig_layer_impl,
|
|||
def replace_fn(child, _policy, layer_id=0):
|
||||
if training:
|
||||
# copy relevant state from child -> new module
|
||||
new_module = replace_with_policy(child, _policy, preln=preln)
|
||||
new_module = replace_with_policy(child,
|
||||
_policy,
|
||||
triangular_masking,
|
||||
preln=preln)
|
||||
|
||||
else:
|
||||
# copy relevant state from child -> new module
|
||||
if replace_with_kernel_inject:
|
||||
new_module = replace_with_policy(child,
|
||||
_policy,
|
||||
triangular_masking,
|
||||
inference=True,
|
||||
preln=(_policy
|
||||
is not HFBertLayerPolicy),
|
||||
|
|
|
@ -55,16 +55,16 @@ class HFBertLayerPolicy(DSPolicy):
|
|||
HFBertLayerPolicy._orig_layer_class = None
|
||||
|
||||
def get_hidden_heads(self):
|
||||
return self.client_module.attention.self.query.weight.data.shape[1], \
|
||||
return self.client_module.attention.self.query.weight.shape[1], \
|
||||
self.client_module.attention.self.num_attention_heads
|
||||
|
||||
def attention(self):
|
||||
qw = self.client_module.attention.self.query.weight.data
|
||||
qb = self.client_module.attention.self.query.bias.data
|
||||
kw = self.client_module.attention.self.key.weight.data
|
||||
kb = self.client_module.attention.self.key.bias.data
|
||||
vw = self.client_module.attention.self.value.weight.data
|
||||
vb = self.client_module.attention.self.value.bias.data
|
||||
qw = self.client_module.attention.self.query.weight
|
||||
qb = self.client_module.attention.self.query.bias
|
||||
kw = self.client_module.attention.self.key.weight
|
||||
kb = self.client_module.attention.self.key.bias
|
||||
vw = self.client_module.attention.self.value.weight
|
||||
vb = self.client_module.attention.self.value.bias
|
||||
|
||||
qkvw = torch.cat((qw, kw, vw), dim=0)
|
||||
qkvb = torch.cat((qb, kb, vb), dim=0)
|
||||
|
@ -72,8 +72,8 @@ class HFBertLayerPolicy(DSPolicy):
|
|||
return self.linear_layer, \
|
||||
qkvw, \
|
||||
qkvb, \
|
||||
self.client_module.attention.output.dense.weight.data, \
|
||||
self.client_module.attention.output.dense.bias.data, \
|
||||
self.client_module.attention.output.dense.weight, \
|
||||
self.client_module.attention.output.dense.bias, \
|
||||
self.scale_attention
|
||||
|
||||
def mlp(self):
|
||||
|
@ -82,9 +82,9 @@ class HFBertLayerPolicy(DSPolicy):
|
|||
else:
|
||||
intermediate_ff = self.client_module.intermediate.dense
|
||||
|
||||
return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \
|
||||
self.client_module.output.dense.weight.data, \
|
||||
self.client_module.output.dense.bias.data
|
||||
return self.linear_layer, intermediate_ff.weight, intermediate_ff.bias, \
|
||||
self.client_module.output.dense.weight, \
|
||||
self.client_module.output.dense.bias
|
||||
|
||||
def layerNorm(self):
|
||||
if self.preln:
|
||||
|
@ -93,10 +93,10 @@ class HFBertLayerPolicy(DSPolicy):
|
|||
else:
|
||||
attention_layernorm = self.client_module.attention.output.LayerNorm
|
||||
transformer_layernorm = self.client_module.output.LayerNorm
|
||||
return attention_layernorm.weight.data, \
|
||||
attention_layernorm.bias.data, \
|
||||
transformer_layernorm.weight.data, \
|
||||
transformer_layernorm.bias.data
|
||||
return attention_layernorm.weight, \
|
||||
attention_layernorm.bias, \
|
||||
transformer_layernorm.weight, \
|
||||
transformer_layernorm.bias
|
||||
|
||||
|
||||
class HFGPTNEOLayerPolicy(DSPolicy):
|
||||
|
@ -112,35 +112,35 @@ class HFGPTNEOLayerPolicy(DSPolicy):
|
|||
HFGPTNEOLayerPolicy._orig_layer_class = None
|
||||
|
||||
def get_hidden_heads(self):
|
||||
return self.client_module.attn.attention.q_proj.weight.data.shape[1], \
|
||||
return self.client_module.attn.attention.q_proj.weight.shape[1], \
|
||||
self.client_module.attn.attention.num_heads
|
||||
|
||||
def attention(self):
|
||||
qw = self.client_module.attn.attention.q_proj.weight.data
|
||||
kw = self.client_module.attn.attention.k_proj.weight.data
|
||||
vw = self.client_module.attn.attention.v_proj.weight.data
|
||||
qw = self.client_module.attn.attention.q_proj.weight
|
||||
kw = self.client_module.attn.attention.k_proj.weight
|
||||
vw = self.client_module.attn.attention.v_proj.weight
|
||||
|
||||
qkvw = torch.cat((qw, kw, vw), dim=0)
|
||||
|
||||
return self.linear_layer, \
|
||||
qkvw, \
|
||||
None, \
|
||||
self.client_module.attn.attention.out_proj.weight.data, \
|
||||
self.client_module.attn.attention.out_proj.bias.data, \
|
||||
self.client_module.attn.attention.out_proj.weight, \
|
||||
self.client_module.attn.attention.out_proj.bias, \
|
||||
self.scale_attention
|
||||
|
||||
def mlp(self):
|
||||
return self.linear_layer, \
|
||||
self.client_module.mlp.c_fc.weight.data, \
|
||||
self.client_module.mlp.c_fc.bias.data, \
|
||||
self.client_module.mlp.c_proj.weight.data, \
|
||||
self.client_module.mlp.c_proj.bias.data
|
||||
self.client_module.mlp.c_fc.weight, \
|
||||
self.client_module.mlp.c_fc.bias, \
|
||||
self.client_module.mlp.c_proj.weight, \
|
||||
self.client_module.mlp.c_proj.bias
|
||||
|
||||
def layerNorm(self):
|
||||
return self.client_module.ln_2.weight.data, \
|
||||
self.client_module.ln_2.bias.data, \
|
||||
self.client_module.ln_1.weight.data, \
|
||||
self.client_module.ln_1.bias.data
|
||||
return self.client_module.ln_2.weight, \
|
||||
self.client_module.ln_2.bias, \
|
||||
self.client_module.ln_1.weight, \
|
||||
self.client_module.ln_1.bias
|
||||
|
||||
|
||||
class HFGPTJLayerPolicy(DSPolicy):
|
||||
|
@ -156,46 +156,47 @@ class HFGPTJLayerPolicy(DSPolicy):
|
|||
HFGPTJLayerPolicy._orig_layer_class = None
|
||||
|
||||
def get_hidden_heads(self):
|
||||
return self.client_module.attn.q_proj.weight.data.shape[1], \
|
||||
return self.client_module.attn.q_proj.weight.shape[1], \
|
||||
self.client_module.attn.num_attention_heads
|
||||
|
||||
def attention(self):
|
||||
qw = self.client_module.attn.q_proj.weight.data
|
||||
kw = self.client_module.attn.k_proj.weight.data
|
||||
vw = self.client_module.attn.v_proj.weight.data
|
||||
qw = self.client_module.attn.q_proj.weight
|
||||
kw = self.client_module.attn.k_proj.weight
|
||||
vw = self.client_module.attn.v_proj.weight
|
||||
|
||||
qkvw = torch.cat((qw, kw, vw), dim=0)
|
||||
|
||||
return self.linear_layer, \
|
||||
qkvw, \
|
||||
None, \
|
||||
self.client_module.attn.out_proj.weight.data, \
|
||||
self.client_module.attn.out_proj.weight, \
|
||||
None, \
|
||||
self.scale_attention
|
||||
|
||||
def mlp(self):
|
||||
return self.linear_layer, \
|
||||
self.client_module.mlp.fc_in.weight.data, \
|
||||
self.client_module.mlp.fc_in.bias.data, \
|
||||
self.client_module.mlp.fc_out.weight.data, \
|
||||
self.client_module.mlp.fc_out.bias.data
|
||||
self.client_module.mlp.fc_in.weight, \
|
||||
self.client_module.mlp.fc_in.bias, \
|
||||
self.client_module.mlp.fc_out.weight, \
|
||||
self.client_module.mlp.fc_out.bias
|
||||
|
||||
def layerNorm(self):
|
||||
return None, \
|
||||
None, \
|
||||
self.client_module.ln_1.weight.data, \
|
||||
self.client_module.ln_1.bias.data
|
||||
self.client_module.ln_1.weight, \
|
||||
self.client_module.ln_1.bias
|
||||
|
||||
|
||||
class MegatronLayerPolicy(DSPolicy):
|
||||
_orig_layer_class = None
|
||||
version = 0
|
||||
moe_type = 'standard'
|
||||
|
||||
def __init__(self, client_module, version=0, inference=True):
|
||||
def __init__(self, client_module, inference=True):
|
||||
super().__init__(inference)
|
||||
self.client_module = client_module
|
||||
# we use megatron version to differentiate between the old and new
|
||||
# megatron-lm source code
|
||||
self.version = version
|
||||
if MegatronLayerPolicy._orig_layer_class is None:
|
||||
try:
|
||||
import megatron
|
||||
|
@ -205,35 +206,62 @@ class MegatronLayerPolicy(DSPolicy):
|
|||
MegatronLayerPolicy._orig_layer_class = None
|
||||
|
||||
def get_hidden_heads(self):
|
||||
return self.client_module.attention.query_key_value.weight.data.shape[1], \
|
||||
return self.client_module.attention.query_key_value.weight.shape[1], \
|
||||
self.client_module.attention.num_attention_heads
|
||||
|
||||
def attention(self):
|
||||
if self.inference:
|
||||
if self.version == 0:
|
||||
if MegatronLayerPolicy.version == 0:
|
||||
attention = self.client_module.attention
|
||||
else:
|
||||
attention = self.client_module.self_attention
|
||||
|
||||
return self.linear_layer, \
|
||||
attention.query_key_value.weight.data, \
|
||||
attention.query_key_value.bias.data, \
|
||||
attention.dense.weight.data, \
|
||||
attention.dense.bias.data, \
|
||||
attention.query_key_value.weight, \
|
||||
attention.query_key_value.bias, \
|
||||
attention.dense.weight, \
|
||||
attention.dense.bias, \
|
||||
self.scale_attention
|
||||
|
||||
def mlp(self):
|
||||
return self.linear_layer, \
|
||||
self.client_module.mlp.dense_h_to_4h.weight.data, \
|
||||
self.client_module.mlp.dense_h_to_4h.bias.data, \
|
||||
self.client_module.mlp.dense_4h_to_h.weight.data, \
|
||||
self.client_module.mlp.dense_4h_to_h.bias.data
|
||||
def mlp(self, moe_type='standard'):
|
||||
from deepspeed.moe.utils import has_moe_layers
|
||||
moe, _ = has_moe_layers(self.client_module)
|
||||
|
||||
if moe:
|
||||
moe_experts = self.client_module.mlp.deepspeed_moe.experts.deepspeed_experts if moe_type == 'standard' else \
|
||||
self.client_module.mlp.moe.deepspeed_moe.experts.deepspeed_experts
|
||||
num_experts = len(moe_experts)
|
||||
if moe_type == 'standard':
|
||||
return self.linear_layer, \
|
||||
[moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)]
|
||||
else:
|
||||
|
||||
return self.linear_layer, \
|
||||
[moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
|
||||
[moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)], \
|
||||
self.client_module.mlp.mlp.dense_h_to_4h.weight, \
|
||||
self.client_module.mlp.mlp.dense_h_to_4h.bias, \
|
||||
self.client_module.mlp.mlp.dense_4h_to_h.weight, \
|
||||
self.client_module.mlp.mlp.dense_4h_to_h.bias, \
|
||||
self.client_module.mlp.coefficient.weight
|
||||
|
||||
else:
|
||||
return self.linear_layer, \
|
||||
self.client_module.mlp.dense_h_to_4h.weight, \
|
||||
self.client_module.mlp.dense_h_to_4h.bias, \
|
||||
self.client_module.mlp.dense_4h_to_h.weight, \
|
||||
self.client_module.mlp.dense_4h_to_h.bias
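# Note (explanatory, not part of this commit): for moe_type='residual' (PR-MoE),
# the tuple above additionally carries the shared residual MLP weights
# (client_module.mlp.mlp.dense_h_to_4h / dense_4h_to_h) and the gating
# coefficient (client_module.mlp.coefficient.weight), so the inference module
# can blend expert output with the residual MLP output.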
|
||||
|
||||
def layerNorm(self):
|
||||
return self.client_module.post_attention_layernorm.weight.data, \
|
||||
self.client_module.post_attention_layernorm.bias.data, \
|
||||
self.client_module.input_layernorm.weight.data, \
|
||||
self.client_module.input_layernorm.bias.data
|
||||
return self.client_module.post_attention_layernorm.weight, \
|
||||
self.client_module.post_attention_layernorm.bias, \
|
||||
self.client_module.input_layernorm.weight, \
|
||||
self.client_module.input_layernorm.bias
|
||||
|
||||
|
||||
class HFGPT2LayerPolicy(DSPolicy):
|
||||
|
@ -255,24 +283,24 @@ class HFGPT2LayerPolicy(DSPolicy):
|
|||
|
||||
def attention(self):
|
||||
return self.linear_layer, \
|
||||
self.client_module.attn.c_attn.weight.data, \
|
||||
self.client_module.attn.c_attn.bias.data, \
|
||||
self.client_module.attn.c_proj.weight.data, \
|
||||
self.client_module.attn.c_proj.bias.data, \
|
||||
self.client_module.attn.c_attn.weight, \
|
||||
self.client_module.attn.c_attn.bias, \
|
||||
self.client_module.attn.c_proj.weight, \
|
||||
self.client_module.attn.c_proj.bias, \
|
||||
self.scale_attention
|
||||
|
||||
def mlp(self):
|
||||
return self.linear_layer, \
|
||||
self.client_module.mlp.c_fc.weight.data, \
|
||||
self.client_module.mlp.c_fc.bias.data, \
|
||||
self.client_module.mlp.c_proj.weight.data, \
|
||||
self.client_module.mlp.c_proj.bias.data
|
||||
self.client_module.mlp.c_fc.weight, \
|
||||
self.client_module.mlp.c_fc.bias, \
|
||||
self.client_module.mlp.c_proj.weight, \
|
||||
self.client_module.mlp.c_proj.bias
|
||||
|
||||
def layerNorm(self):
|
||||
return self.client_module.ln_2.weight.data, \
|
||||
self.client_module.ln_2.bias.data, \
|
||||
self.client_module.ln_1.weight.data, \
|
||||
self.client_module.ln_1.bias.data
|
||||
return self.client_module.ln_2.weight, \
|
||||
self.client_module.ln_2.bias, \
|
||||
self.client_module.ln_1.weight, \
|
||||
self.client_module.ln_1.bias
|
||||
|
||||
|
||||
replace_policies = [
@ -7,7 +7,7 @@ import copy
|
|||
|
||||
|
||||
class Experts(torch.nn.Module):
|
||||
def __init__(self, expert, num_local_experts=1):
|
||||
def __init__(self, expert, num_local_experts=1, expert_group_name=None):
|
||||
super(Experts, self).__init__()
|
||||
|
||||
self.deepspeed_experts = torch.nn.ModuleList(
|
||||
|
@ -19,6 +19,7 @@ class Experts(torch.nn.Module):
|
|||
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
|
||||
for name, param in expert.named_parameters():
|
||||
param.allreduce = False
|
||||
param.group_name = expert_group_name
|
||||
|
||||
def forward(self, inputs):
|
||||
chunks = inputs.chunk(self.num_local_experts, dim=1)
|
||||
|
|
|
@ -61,14 +61,25 @@ class MoE(torch.nn.Module):
        assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
            'Unsupported noisy_gate_policy: ' + noisy_gate_policy

        num_local_experts = num_experts // groups.get_expert_parallel_world_size()
        # Get the group name
        max_ep_size = groups.get_max_expert_size()
        if max_ep_size >= num_experts:
            self.expert_group_name = f"ep_size_{num_experts}"
        else:
            self.expert_group_name = f"ep_size_{max_ep_size}"

        num_local_experts = 1 if num_experts < groups.get_expert_parallel_world_size(
            self.expert_group_name
        ) else num_experts // groups.get_expert_parallel_world_size(
            self.expert_group_name)

        log_dist(
            f'num_experts: {num_experts} | num_local_experts: {num_local_experts} | expert_parallel_size: {groups.get_expert_parallel_world_size()}',
            f'num_experts: {num_experts} | num_local_experts: {num_local_experts} | expert_parallel_size: {groups.get_expert_parallel_world_size(self.expert_group_name)}',
            [0])
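        # Worked example (illustrative, not part of this commit; it assumes the
        # expert-parallel world size of each group matches the size in its name):
        # with a maximum expert parallel size of 8, a layer with 64 experts is
        # placed in group "ep_size_8" and keeps 64 // 8 = 8 local experts, while
        # a layer with only 4 experts is placed in "ep_size_4" and keeps
        # 4 // 4 = 1 local expert.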
|
||||
|
||||
self.num_experts = num_experts
|
||||
experts = Experts(expert, num_local_experts)
|
||||
self.num_local_experts = num_local_experts
|
||||
experts = Experts(expert, num_local_experts, self.expert_group_name)
|
||||
self.deepspeed_moe = MOELayer(TopKGate(hidden_size,
|
||||
num_experts,
|
||||
k,
|
||||
|
@ -80,7 +91,8 @@ class MoE(torch.nn.Module):
|
|||
use_rts),
|
||||
experts,
|
||||
num_local_experts,
|
||||
group=groups.get_expert_parallel_group(),
|
||||
group=groups.get_expert_parallel_group(
|
||||
self.expert_group_name),
|
||||
use_tutel=use_tutel)
|
||||
|
||||
def forward(self, hidden_states, used_token=None):
|
||||
|
|
|
@ -364,7 +364,7 @@ class TopKGate(Module):
|
|||
k: int = 1,
|
||||
capacity_factor: float = 1.0,
|
||||
eval_capacity_factor: float = 1.0,
|
||||
min_capacity: int = 4,
|
||||
min_capacity: int = 8,
|
||||
noisy_gate_policy: Optional[str] = None,
|
||||
drop_tokens: bool = True,
|
||||
use_rts: bool = True) -> None:
|
||||
|
|
|
@ -1,5 +1,18 @@
|
|||
from typing import List, Tuple
from typing import List, Tuple, Dict
import torch
import deepspeed.utils.groups as groups
from .layer import MoE


def has_moe_layers(m):
    has_moe = False
    num_experts = 0
    for _, module in m.named_modules():
        if isinstance(module, MoE):
            has_moe = True
            num_experts = module.num_experts
            break
    return has_moe, num_experts
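# Usage sketch (illustrative, not part of this commit): detect MoE layers before
# wiring up process groups or checkpoint logic; `model` here stands for any
# torch.nn.Module that may contain MoE layers.
#   moe, num_experts = has_moe_layers(model)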
|
||||
|
||||
|
||||
def is_moe_param(param: torch.Tensor) -> bool:
|
||||
|
@ -44,3 +57,57 @@ def split_params_grads_into_shared_and_expert_params(
|
|||
else:
|
||||
shared_grads.append(p.grad.to(p.dtype))
|
||||
return shared_grads, expert_grads
|
||||
|
||||
|
||||
def split_params_into_different_moe_groups_for_optimizer(
|
||||
param_groups: Tuple[Dict]) -> Tuple[Dict]:
|
||||
"""Split parameters into different MoE groups for optimizer
|
||||
|
||||
Args:
|
||||
param_groups (Tuple[Dict]):
|
||||
The list of parameter groups to split
|
||||
|
||||
Returns:
|
||||
Tuple[Dict]:
|
||||
list of MoE/non-MoE groups for optimizer
|
||||
"""
|
||||
if isinstance(param_groups, tuple):
|
||||
param_groups = list(param_groups) # Tuple cannot be modified
|
||||
elif isinstance(param_groups, dict):
|
||||
param_groups = [param_groups]
|
||||
elif not isinstance(param_groups, list):
|
||||
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
||||
|
||||
group_moe = {}
|
||||
|
||||
# Create the param MoE groups, leave param assign to next step
|
||||
for param_group in param_groups:
|
||||
group_moe[param_group['name']] = {}
|
||||
for key in groups.get_expert_data_parallel_group_dict().keys():
|
||||
group_moe[param_group['name']][key] = {}
|
||||
group_moe[param_group['name']][key]['name'] = key
|
||||
group_moe[param_group['name']][key]['moe'] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != 'name':
|
||||
if ori_key == 'params':
|
||||
group_moe[param_group['name']][key][ori_key] = []
|
||||
else:
|
||||
group_moe[
|
||||
param_group['name']][key][ori_key] = param_group[ori_key]
|
||||
# Assign param
|
||||
for param_group in param_groups:
|
||||
new_params = []
|
||||
for param in param_group['params']:
|
||||
if is_moe_param(param):
|
||||
group_moe[param_group['name']][param.group_name]['params'].append(param)
|
||||
# param_group['params'].remove(param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
param_group['params'] = new_params
|
||||
|
||||
# Flatten the moe groups
|
||||
for k, v in group_moe.items():
|
||||
for k1, v1 in v.items():
|
||||
param_groups.append(v1)
|
||||
|
||||
return tuple(param_groups)
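# Minimal usage sketch (not part of this commit; `model` and the optimizer
# choice are hypothetical): each incoming param group must carry a 'name' key,
# and the returned tuple contains the original groups plus one extra group per
# expert-parallel group, marked with 'moe': True.
#   param_groups = [{'name': 'base', 'params': list(model.parameters()), 'lr': 1e-4}]
#   param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)
#   optimizer = torch.optim.AdamW(param_groups, lr=1e-4)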
@ -1,2 +1,3 @@
|
|||
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
|
||||
from .inference.transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig
|
||||
from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
from .transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig
|
||||
from .moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference
|
||||
|
|
|
@ -0,0 +1,468 @@
|
|||
'''
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
'''
|
||||
import json
|
||||
import math
|
||||
import importlib
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import time
|
||||
from ... import op_builder
|
||||
#from ...inference.engine import inference_cuda_module, specialized_mode
|
||||
# Cuda modules will be imported if needed
|
||||
inference_cuda_module = None
|
||||
specialized_mode = None
|
||||
import torch.nn as nn
|
||||
from .transformer_inference import DeepSpeedSelfAttention, DeepSpeedInferenceConfig
|
||||
from ....moe.sharded_moe import TopKGate
|
||||
import torch.distributed as dist
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
|
||||
"""Initialize the DeepSpeed Transformer Config.
|
||||
Arguments:
|
||||
hidden_size: The hidden size of the transformer layer
|
||||
intermediate_size: The intermediate size of the feed-forward part of transformer layer
|
||||
heads: The number of heads in the self-attention of the transformer layer
|
||||
num_hidden_layers: The number of transformer layers
|
||||
layer_norm_eps: The epsilon value for the layer norm
|
||||
local_rank: Optional: The rank of GPU running the transformer kernel, it is not required
|
||||
to use if the model already set the current device, otherwise need to set it
|
||||
so that the transformer kernel can work on the right device
|
||||
mp_size (optional): This argument is mainly used to create the parameters on the kernel side
|
||||
using model-parallel architecture. If the client model already takes care of this, there is no
|
||||
need to pass this argument.
|
||||
fp16: Enable half-precision computation
|
||||
pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
|
||||
stochastic_mode: Enable for high performance, please note that this flag has some level of
|
||||
non-determinism and can produce different results on different runs. However, we have seen
|
||||
that by enabling it, the pretraining tasks such as BERT are not affected and can obtain
|
||||
a high accuracy level. On the other hand, for downstream tasks such as fine-tuning, we recommend
turning it off so that the same result can be reproduced through the regular kernel execution.
|
||||
|
||||
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
|
||||
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size=-1,
|
||||
intermediate_size=-1,
|
||||
heads=-1,
|
||||
num_hidden_layers=-1,
|
||||
layer_norm_eps=1e-12,
|
||||
local_rank=-1,
|
||||
mp_size=1,
|
||||
fp16=False,
|
||||
q_int8=False,
|
||||
pre_layer_norm=True,
|
||||
stochastic_mode=False,
|
||||
scale_attention=True,
|
||||
triangular_masking=True,
|
||||
local_attention=False,
|
||||
window_size=256,
|
||||
return_tuple=True,
|
||||
moe_experts=1,
|
||||
global_experts=1,
|
||||
k=1,
|
||||
capacity_factor=1.,
|
||||
eval_capacity_factor=1.,
|
||||
min_capacity=1,
|
||||
noisy_gate_policy=None,
|
||||
drop_tokens=True,
|
||||
use_rts=False,
|
||||
mlp_type='standard'):
|
||||
super(DeepSpeedMoEInferenceConfig,
|
||||
self).__init__(
|
||||
hidden_size,
|
||||
(intermediate_size if intermediate_size > 0 else 4 * hidden_size),
|
||||
heads,
|
||||
num_hidden_layers,
|
||||
layer_norm_eps,
|
||||
local_rank,
|
||||
mp_size,
|
||||
fp16,
|
||||
q_int8,
|
||||
pre_layer_norm,
|
||||
stochastic_mode,
|
||||
scale_attention,
|
||||
triangular_masking,
|
||||
local_attention,
|
||||
window_size,
|
||||
return_tuple)
|
||||
self.moe_experts = moe_experts
|
||||
self.k = k
|
||||
self.capacity_factor = capacity_factor
|
||||
self.eval_capacity_factor = eval_capacity_factor
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_gate_policy = noisy_gate_policy
|
||||
self.drop_tokens = drop_tokens
|
||||
self.use_rts = use_rts
|
||||
self.global_experts = global_experts
|
||||
self.mlp_type = mlp_type
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
config = DeepSpeedInferenceConfig()
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
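# Illustrative construction (a sketch, not part of this commit; the values are
# made up and only use arguments documented above):
#   moe_config = DeepSpeedMoEInferenceConfig(hidden_size=1024,
#                                            heads=16,
#                                            fp16=True,
#                                            moe_experts=8,
#                                            global_experts=64,
#                                            k=1,
#                                            mlp_type='standard')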
|
||||
|
||||
|
||||
class DeepSpeedMLPFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
input,
|
||||
inter_w,
|
||||
inter_b,
|
||||
config,
|
||||
output_b,
|
||||
output_w,
|
||||
q_scales,
|
||||
q_groups,
|
||||
merge_count,
|
||||
mp_group,
|
||||
async_op):
|
||||
if config.q_int8:
|
||||
intermediate = inference_cuda_module.fused_gemm_gelu_int8(
|
||||
input,
|
||||
inter_w,
|
||||
inter_b,
|
||||
config.epsilon,
|
||||
q_scales[2],
|
||||
(q_groups * (2**merge_count)),
|
||||
config.pre_layer_norm)
|
||||
output = inference_cuda_module.vector_matmul_int8(intermediate,
|
||||
output_w,
|
||||
q_scales[3],
|
||||
q_groups,
|
||||
(merge_count))
|
||||
else:
|
||||
mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
|
||||
inference_cuda_module.fused_gemm_gelu_fp32
|
||||
|
||||
output = mlp_gemm_func(input,
|
||||
inter_w,
|
||||
inter_b,
|
||||
output_w,
|
||||
config.epsilon,
|
||||
config.pre_layer_norm,
|
||||
async_op)
|
||||
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
|
||||
dist.all_reduce(output, group=mp_group, async_op=async_op)
|
||||
|
||||
return output + output_b
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise RuntimeError('You are running with DeepSpeed Inference mode. \
|
||||
Please switch to Training mode for running backward!')
|
||||
|
||||
|
||||
class DeepSpeedMoEMLP(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
q_scales=None,
|
||||
q_groups=1,
|
||||
merge_count=1,
|
||||
mlp_extra_grouping=False,
|
||||
mp_group=None):
|
||||
super(DeepSpeedMoEMLP, self).__init__()
|
||||
|
||||
self.config = config
|
||||
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
interm_size = self.config.intermediate_size // (
|
||||
1 if mp_group is None else dist.get_world_size(group=mp_group))
|
||||
self.inter_w = nn.Parameter(torch.Tensor(self.config.hidden_size, interm_size))
|
||||
self.inter_b = nn.Parameter(torch.Tensor(interm_size))
|
||||
self.output_w = nn.Parameter(torch.Tensor((interm_size),
|
||||
self.config.hidden_size))
|
||||
self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
|
||||
# used for quantization
|
||||
self.q_scales = q_scales
|
||||
self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
|
||||
self.merge_count = int(math.log2(merge_count))
|
||||
self.mp_group = mp_group
|
||||
|
||||
def forward(self, input, async_op=False):
|
||||
return DeepSpeedMLPFunction.apply(input,
|
||||
self.inter_w,
|
||||
self.inter_b,
|
||||
self.config,
|
||||
self.output_b,
|
||||
self.output_w,
|
||||
self.q_scales,
|
||||
self.q_groups,
|
||||
self.merge_count,
|
||||
self.mp_group,
|
||||
async_op)
|
||||
|
||||
|
||||
class DeepSpeedMoEInference(nn.Module):
|
||||
"""Initialize the DeepSpeed MoE Transformer Layer.
|
||||
Arguments:
|
||||
layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
|
||||
layer_id will be 0,1,2...23 when each layer object is instantiated
|
||||
config: An object of DeepSpeedInferenceConfig
|
||||
mp_group: Model parallelism group initialized on the modeling side.
|
||||
quantize_scales: This argument groups all the layers' scales used for quantization
|
||||
quantize_groups: Number of groups used for quantizing the model
|
||||
merge_count: Shows the number of model-parallel checkpoints merged before running inference.
|
||||
We use this argument to control the quantization scale for the model parameters if a bigger
|
||||
quantize-grouping than 1 is used.
|
||||
mlp_extra_grouping: This flag is used to show a 2x higher number of groups used for the MLP part
|
||||
of a Transformer layer. We use this feature for quantization to reduce the convergence impact
|
||||
for specific downstream tasks.
|
||||
"""
|
||||
layer_id = 0
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
mp_group=None,
|
||||
ep_group=None,
|
||||
expert_mp_group=None,
|
||||
quantize_scales=None,
|
||||
quantize_groups=1,
|
||||
merge_count=1,
|
||||
mlp_extra_grouping=False,
|
||||
qkv_merging=False):
|
||||
super(DeepSpeedMoEInference, self).__init__()
|
||||
|
||||
self.config = config
|
||||
self.config.layer_id = DeepSpeedMoEInference.layer_id
|
||||
global inference_cuda_module
|
||||
global specialized_mode
|
||||
if inference_cuda_module is None:
|
||||
specialized_mode = False
|
||||
if hasattr(op_builder, 'InferenceSpecializedBuilder'):
|
||||
builder = op_builder.InferenceSpecializedBuilder()
|
||||
if builder.is_compatible():
|
||||
inference_cuda_module = builder.load()
|
||||
specialized_mode = True
|
||||
else:
|
||||
inference_cuda_module = op_builder.InferenceBuilder().load()
|
||||
else:
|
||||
inference_cuda_module = op_builder.InferenceBuilder().load()
|
||||
self.config.specialized_mode = specialized_mode
|
||||
|
||||
DeepSpeedMoEInference.layer_id += 1
|
||||
self.attention = DeepSpeedSelfAttention(self.config,
|
||||
mp_group,
|
||||
quantize_scales,
|
||||
quantize_groups,
|
||||
merge_count,
|
||||
qkv_merging)
|
||||
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
|
||||
self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
|
||||
if config.mlp_type == 'residual':
|
||||
self.res_mlp = DeepSpeedMoEMLP(config,
|
||||
quantize_scales,
|
||||
quantize_groups,
|
||||
merge_count,
|
||||
mlp_extra_grouping,
|
||||
mp_group)
|
||||
self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2))
|
||||
self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \
|
||||
inference_cuda_module.softmax_fp32
|
||||
self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
|
||||
inference_cuda_module.vector_matmul_fp32
|
||||
|
||||
config.mp_size = 1
|
||||
self.mlp = nn.ModuleList(
|
||||
DeepSpeedMoEMLP(config,
|
||||
quantize_scales,
|
||||
quantize_groups,
|
||||
merge_count,
|
||||
mlp_extra_grouping,
|
||||
expert_mp_group) for i in range(self.config.moe_experts))
|
||||
|
||||
self.moe_gate = TopKGate(self.config.hidden_size,
|
||||
self.config.global_experts,
|
||||
self.config.k,
|
||||
self.config.capacity_factor,
|
||||
self.config.eval_capacity_factor,
|
||||
self.config.min_capacity,
|
||||
self.config.noisy_gate_policy,
|
||||
self.config.drop_tokens,
|
||||
self.config.use_rts)
|
||||
|
||||
self.ep_group = ep_group
|
||||
self.mp_group = mp_group
|
||||
self.expert_mp_group = expert_mp_group
|
||||
|
||||
print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__)
|
||||
|
||||
self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
|
||||
inference_cuda_module.bias_residual_fp32
|
||||
self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \
|
||||
inference_cuda_module.layer_norm_fp32
|
||||
self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.fp16 or self.config.q_int8 else \
|
||||
inference_cuda_module.einsum_sec_sm_ecm_fp32
|
||||
|
||||
def res_coef_func(self, inp, async_op):
|
||||
inp = self.vector_matmul_func(inp, self.res_coef, async_op)
|
||||
return self.coef_func(inp, torch.empty(1), False, False, False, 256, async_op)
|
||||
|
||||
def moe_gate_einsum(self, attention_output):
|
||||
_, combined_weights, dispatch_mask, _ = self.moe_gate(
|
||||
attention_output.view(-1, self.config.hidden_size),
|
||||
None,
|
||||
)
|
||||
dispatched_attention = self.einsum_sec_sm_ecm(
|
||||
dispatch_mask.type_as(attention_output),
|
||||
attention_output.view(-1,
|
||||
self.config.hidden_size))
|
||||
return dispatched_attention, combined_weights
|
||||
|
||||
def expert_exec(self, dispatched_input):
|
||||
dispatched_input = dispatched_input.reshape(
|
||||
self.config.global_experts // self.config.moe_experts,
|
||||
self.config.moe_experts,
|
||||
-1,
|
||||
self.config.hidden_size)
|
||||
|
||||
chunks = dispatched_input.chunk(self.config.moe_experts, dim=1)
|
||||
expert_outputs = torch.empty((
|
||||
self.config.moe_experts,
|
||||
chunks[0].shape[0],
|
||||
) + chunks[0].shape[2:],
|
||||
dtype=dispatched_input.dtype,
|
||||
device=dispatched_input.device)
|
||||
for chunk, expert in zip(chunks, range(len(self.mlp))):
|
||||
expert_outputs[expert] = self.mlp[expert](chunk.view(
|
||||
-1,
|
||||
dispatched_input.shape[-2],
|
||||
dispatched_input.shape[-1]))
|
||||
return expert_outputs
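# Shape note (informal, not part of this commit): dispatched_input is viewed as
# roughly (expert_parallel_world_size, num_local_experts, tokens_per_expert,
# hidden_size), where expert_parallel_world_size ~= global_experts // moe_experts,
# and each local DeepSpeedMoEMLP consumes its own chunk along the second dimension.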
|
||||
|
||||
def _alltoall(self, dispatched_attention):
|
||||
if dist.get_world_size(group=self.ep_group) > 1:
|
||||
dispatched_input = torch.empty_like(dispatched_attention)
|
||||
dist.all_to_all_single(dispatched_input,
|
||||
dispatched_attention,
|
||||
group=self.ep_group)
|
||||
return dispatched_input
|
||||
else:
|
||||
return dispatched_attention
|
||||
|
||||
def scale_expert_output(self, attention_output, expert_output, combined_weights):
|
||||
combined_output = torch.matmul(
|
||||
combined_weights.type_as(attention_output).reshape(
|
||||
combined_weights.shape[0],
|
||||
-1),
|
||||
expert_output.reshape(-1,
|
||||
expert_output.shape[-1]))
|
||||
return combined_output.reshape(attention_output.shape)
|
||||
|
||||
def forward(self,
|
||||
input,
|
||||
input_mask=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
layer_past=None,
|
||||
get_key_value=False,
|
||||
get_present=False,
|
||||
encoder_output=None,
|
||||
enc_dec_attn_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False):
|
||||
get_present = (get_present or get_key_value or use_cache)
|
||||
input_mask = input_mask if attention_mask is None else attention_mask
|
||||
input_type = input.dtype
|
||||
|
||||
if (self.config.fp16 or self.config.q_int8) \
|
||||
and input.dtype == torch.float:
|
||||
input = input.half()
|
||||
|
||||
with torch.no_grad():
|
||||
attention_output = self.attention(input,
|
||||
input_mask,
|
||||
head_mask,
|
||||
layer_past,
|
||||
get_present,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions,
|
||||
self.norm_w,
|
||||
self.norm_b)
|
||||
|
||||
if get_present:
|
||||
attention_output, p_key, p_value = attention_output[0:3]
|
||||
presents = (p_key, p_value)
|
||||
elif output_attentions:
|
||||
attention_output, _, _, context_output = attention_output[0:4]
|
||||
else:
|
||||
attention_output = attention_output[0]
|
||||
|
||||
residual_add = attention_output + self.attention.attn_ob
|
||||
attention_output = self.ds_layernorm(residual_add,
|
||||
self.attn_nw,
|
||||
self.attn_nb,
|
||||
self.config.epsilon)
|
||||
|
||||
if self.config.mlp_type == 'residual':
|
||||
res_mlp_out = self.res_mlp(attention_output, async_op=True)
|
||||
res_coef_out = self.res_coef_func(attention_output, async_op=True)
|
||||
|
||||
if self.expert_mp_group is not None:
|
||||
tensor_list = [
|
||||
torch.empty_like(attention_output)
|
||||
for _ in range(dist.get_world_size(group=self.expert_mp_group))
|
||||
]
|
||||
tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output
|
||||
dist.all_gather(tensor_list,
|
||||
attention_output,
|
||||
group=self.expert_mp_group)
|
||||
attention_output = torch.cat(tensor_list).contiguous()
|
||||
|
||||
############## MoE Gating + Experts ###############
|
||||
dispatched_attention, combined_weights = self.moe_gate_einsum(attention_output)
|
||||
dispatched_input = self._alltoall(dispatched_attention)
|
||||
expert_outputs = self.expert_exec(dispatched_input)
|
||||
expert_output = self._alltoall(expert_outputs)
|
||||
output = self.scale_expert_output(attention_output,
|
||||
expert_output,
|
||||
combined_weights)
|
||||
################################################
|
||||
|
||||
if self.expert_mp_group is not None:
|
||||
output = output.split(output.shape[0] //
|
||||
dist.get_world_size(group=self.expert_mp_group),
|
||||
dim=0)[dist.get_rank(group=self.expert_mp_group)]
|
||||
|
||||
if self.config.mlp_type == 'residual':
|
||||
inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output)
|
||||
|
||||
output = self.bias_residual_func(output, residual_add, torch.empty(1))
|
||||
|
||||
if not self.config.pre_layer_norm:
|
||||
output = self.ds_layernorm(output,
|
||||
self.norm_w,
|
||||
self.norm_b,
|
||||
self.config.epsilon)
|
||||
|
||||
if input_type != output.dtype:
|
||||
output = output.to(input_type)
|
||||
|
||||
if get_present:
|
||||
output = (output, presents)
|
||||
|
||||
if self.config.return_tuple:
|
||||
return output if type(output) is tuple else (output, )
|
||||
else:
|
||||
return output
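# Rough shape of the MoE path above (a sketch, not part of this commit):
#   attention -> layernorm -> TopKGate dispatch (einsum) -> all_to_all ->
#   local experts -> all_to_all -> combine with gate weights ->
#   optional residual-MLP blend (mlp_type='residual', i.e. PR-MoE) -> bias + residual.
# Illustrative construction, reusing the hypothetical moe_config sketched
# earlier; process groups would normally come from deepspeed.utils.groups:
#   moe_layer = DeepSpeedMoEInference(moe_config, mp_group=None, ep_group=None)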
@ -9,11 +9,11 @@ from torch import nn
|
|||
from torch.autograd import Function
|
||||
import time
|
||||
from ... import op_builder
|
||||
#from ...inference.engine import inference_cuda_module, specialized_mode
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
# Cuda modules will be imported if needed
|
||||
inference_cuda_module = None
|
||||
specialized_mode = None
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TransformerConfig():
|
||||
|
@ -83,13 +83,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
|
|||
self.mp_size = mp_size
|
||||
self.q_int8 = q_int8
|
||||
self.scale_attention = scale_attention
|
||||
self.specialized_mode = None
|
||||
self.triangular_masking = triangular_masking
|
||||
self.local_attention = local_attention
|
||||
self.window_size = window_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.return_tuple = return_tuple
|
||||
self.mlp_after_attn = mlp_after_attn
|
||||
self.specialized_mode = False
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
|
@ -131,6 +131,11 @@ class DeepSpeedSelfAttentionFunction(Function):
|
|||
q_groups,
|
||||
merge_count,
|
||||
qkv_merging):
|
||||
|
||||
#while len(input_mask.shape) < 4:
|
||||
# input_mask = input_mask.unsqueeze(0)
|
||||
input_mask = torch.empty(1, device='cuda')
|
||||
|
||||
def _transpose_for_scores(x, key=False, reshape=False):
|
||||
attention_head_size = x.shape[-1] // num_attention_heads_per_partition
|
||||
new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
|
||||
|
@ -270,7 +275,7 @@ class DeepSpeedSelfAttentionFunction(Function):
|
|||
config.epsilon,
|
||||
(attn_qkvb is not None))
|
||||
context_layer, key_layer, value_layer = compute_attention(qkv_out[0], input_mask)
|
||||
output = vector_matmul_func(context_layer, attn_ow)
|
||||
output = vector_matmul_func(context_layer, attn_ow, False)
|
||||
|
||||
return output, key_layer, value_layer, context_layer, qkv_out[-1] # attn_out, present_key, present_value, context_output, inp_norm
|
||||
|
||||
|
@ -282,7 +287,9 @@ class DeepSpeedSelfAttentionFunction(Function):
|
|||
attn_qkvb,
|
||||
q_scales[0],
|
||||
(q_groups * (3 if qkv_merging else 1) * (2**merge_count)))
|
||||
|
||||
else:
|
||||
#import pdb;pdb.set_trace()
|
||||
qkv_out = inference_cuda_module.qkv_gemm_int8(
|
||||
input,
|
||||
attn_qkvw,
|
||||
|
@ -306,8 +313,8 @@ class DeepSpeedSelfAttentionFunction(Function):
|
|||
else:
|
||||
output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
|
||||
|
||||
if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
|
||||
torch.distributed.all_reduce(output, group=mp_group)
|
||||
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
|
||||
dist.all_reduce(output, group=mp_group)
|
||||
|
||||
return (output, key_layer, value_layer, context_layer, inp_norm)
|
||||
|
||||
|
@ -318,6 +325,8 @@ class DeepSpeedSelfAttentionFunction(Function):
|
|||
|
||||
|
||||
class DeepSpeedSelfAttention(nn.Module):
|
||||
num_layers = 0
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
mp_group=None,
|
||||
|
@ -327,7 +336,8 @@ class DeepSpeedSelfAttention(nn.Module):
|
|||
qkv_merging=False):
|
||||
super(DeepSpeedSelfAttention, self).__init__()
|
||||
self.config = config
|
||||
|
||||
self.config.layer_id = DeepSpeedSelfAttention.num_layers
|
||||
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
|
||||
self.attn_qkvw = nn.Parameter(
|
||||
torch.Tensor(self.config.hidden_size,
|
||||
(self.config.hidden_size // self.config.mp_size) * 3))
|
||||
|
@ -436,7 +446,13 @@ class DeepSpeedMLPFunction(Function):
|
|||
(merge_count))
|
||||
else:
|
||||
if attn_nw is None:
|
||||
output = fused_gemm_gelu(input, inter_w, inter_b, output_w)
|
||||
output = fused_gemm_gelu(input,
|
||||
inter_w,
|
||||
inter_b,
|
||||
output_w,
|
||||
config.epsilon,
|
||||
config.pre_layer_norm,
|
||||
False)
|
||||
else:
|
||||
(intermediate,
|
||||
residual_add) = mlp_gemm_func(input,
|
||||
|
@ -448,10 +464,10 @@ class DeepSpeedMLPFunction(Function):
|
|||
attn_nb,
|
||||
config.epsilon,
|
||||
config.pre_layer_norm)
|
||||
output = vector_matmul_func(intermediate, output_w)
|
||||
output = vector_matmul_func(intermediate, output_w, False)
|
||||
|
||||
if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
|
||||
torch.distributed.all_reduce(output, group=mp_group)
|
||||
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
|
||||
dist.all_reduce(output, group=mp_group)
|
||||
|
||||
if attn_nw is not None:
|
||||
output = bias_residual_func(output, residual_add, output_b)
|
||||
|
@ -557,19 +573,10 @@ class DeepSpeedTransformerInference(nn.Module):
|
|||
DeepSpeedTransformerInference.layer_id += 1
|
||||
|
||||
global inference_cuda_module
|
||||
global specialized_mode
|
||||
if inference_cuda_module is None:
|
||||
specialized_mode = False
|
||||
if hasattr(op_builder, 'InferenceSpecializedBuilder'):
|
||||
builder = op_builder.InferenceSpecializedBuilder()
|
||||
if builder.is_compatible():
|
||||
inference_cuda_module = builder.load()
|
||||
specialized_mode = True
|
||||
else:
|
||||
inference_cuda_module = op_builder.InferenceBuilder().load()
|
||||
else:
|
||||
inference_cuda_module = op_builder.InferenceBuilder().load()
|
||||
self.config.specialized_mode = specialized_mode
|
||||
builder = op_builder.InferenceBuilder()
|
||||
inference_cuda_module = builder.load()
|
||||
|
||||
print("DeepSpeed Transformer Inference config is ", self.config.__dict__)
|
||||
|
||||
self.attention = DeepSpeedSelfAttention(self.config,
|
||||
|
@ -649,7 +656,6 @@ class DeepSpeedTransformerInference(nn.Module):
|
|||
self.mlp.output_b)
|
||||
|
||||
output = output.to(input_type)
|
||||
|
||||
if get_present:
|
||||
output = (output, presents)
|
||||
|
||||
|
|
|
@ -204,7 +204,7 @@ class DeepSpeedEngine(Module):
|
|||
self.gas_boundary_ctr = 0
|
||||
self.dist_backend = "nccl"
|
||||
self.has_moe_layers = False
|
||||
self.num_experts = None
|
||||
self.num_experts = []
|
||||
self.gate_modules = []
|
||||
self.moe_layers = []
|
||||
self._step_applied = False
|
||||
|
@ -957,11 +957,12 @@ class DeepSpeedEngine(Module):
|
|||
return True
|
||||
|
||||
for p in self.module.parameters():
|
||||
# Broadcast the model for different parameters
|
||||
if hasattr(p, 'allreduce') and not p.allreduce:
|
||||
if torch.is_tensor(p) and is_replicated(p):
|
||||
dist.broadcast(p,
|
||||
self.expert_broadcast_src_rank,
|
||||
group=self.expert_data_parallel_group)
|
||||
self.expert_broadcast_src_rank[p.group_name],
|
||||
group=self.expert_data_parallel_group[p.group_name])
|
||||
else:
|
||||
if torch.is_tensor(p) and is_replicated(p):
|
||||
dist.broadcast(p,
|
||||
|
@ -1004,8 +1005,7 @@ class DeepSpeedEngine(Module):
|
|||
for _, module in self.module.named_modules():
|
||||
if isinstance(module, MoE):
|
||||
self.has_moe_layers = True
|
||||
self.num_experts = module.num_experts
|
||||
break
|
||||
self.num_experts.append(module.num_experts)
|
||||
|
||||
if self.has_moe_layers:
|
||||
for _, module in self.module.named_modules():
|
||||
|
@ -1055,13 +1055,15 @@ class DeepSpeedEngine(Module):
|
|||
|
||||
if self.has_moe_layers:
|
||||
# No assert needed because this will only be true if MoE Layer creation was successful
|
||||
self.expert_data_parallel_group = groups.get_expert_data_parallel_group()
|
||||
self.expert_parallel_group = groups.get_expert_parallel_group()
|
||||
self.ep_world_size = groups.get_expert_parallel_world_size()
|
||||
self.expert_broadcast_src_rank = _get_global_rank(
|
||||
groups.get_expert_data_parallel_group(),
|
||||
0)
|
||||
|
||||
self.expert_data_parallel_group = groups.get_expert_data_parallel_group_dict(
|
||||
)
|
||||
self.expert_parallel_group = groups.get_expert_parallel_group_dict()
|
||||
self.expert_broadcast_src_rank = {}
|
||||
for _key in self.expert_data_parallel_group.keys(): # _key is a string
|
||||
self.expert_broadcast_src_rank[_key] = _get_global_rank(
|
||||
groups.get_expert_data_parallel_group(_key),
|
||||
0)
|
||||
if not self.amp_enabled():
|
||||
self._broadcast_model()
|
||||
|
||||
|
@ -2113,7 +2115,12 @@ class DeepSpeedEngine(Module):
|
|||
self.allreduce_and_copy(small_bucket, dp_group)
|
||||
|
||||
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
|
||||
grads, expert_grads = [], []
|
||||
grads = []
|
||||
expert_grads = {}
|
||||
if self.has_moe_layers:
|
||||
for key in self.expert_data_parallel_group.keys():
|
||||
expert_grads[key] = []
|
||||
|
||||
for param_name, param in self.module.named_parameters():
|
||||
if hasattr(param, 'allreduce') and not param.allreduce:
|
||||
is_moe_param = True
|
||||
|
@ -2129,19 +2136,19 @@ class DeepSpeedEngine(Module):
|
|||
dtype=param.dtype,
|
||||
device=param.device)
|
||||
if is_moe_param:
|
||||
expert_grads.append(param.grad.data)
|
||||
expert_grads[param.group_name].append(param.grad.data)
|
||||
else:
|
||||
grads.append(param.grad.data)
|
||||
else:
|
||||
grad_data = param.grad.data
|
||||
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
|
||||
if is_moe_param:
|
||||
expert_grads.append(SparseTensor(grad_data))
|
||||
expert_grads[param.group_name].append(SparseTensor(grad_data))
|
||||
else:
|
||||
grads.append(SparseTensor(grad_data))
|
||||
else:
|
||||
if is_moe_param:
|
||||
expert_grads.append(grad_data)
|
||||
expert_grads[param.group_name].append(grad_data)
|
||||
else:
|
||||
grads.append(grad_data)
|
||||
|
||||
|
@ -2162,19 +2169,20 @@ class DeepSpeedEngine(Module):
|
|||
numel_per_bucket=elements_per_buffer)
|
||||
|
||||
if self.has_moe_layers:
|
||||
expert_split_buckets = split_half_float_double_sparse(expert_grads)
|
||||
for i, bucket_tuple in enumerate(expert_split_buckets):
|
||||
bucket_type, bucket = bucket_tuple
|
||||
if bucket_type == SparseTensor.type():
|
||||
self.sparse_allreduce_no_retain(
|
||||
bucket,
|
||||
groups.get_expert_data_parallel_group())
|
||||
else:
|
||||
# Separate between diff groups
|
||||
self.allreduce_no_retain(
|
||||
bucket,
|
||||
dp_group=groups.get_expert_data_parallel_group(),
|
||||
numel_per_bucket=elements_per_buffer)
|
||||
for ep_name, expert_grads_group in expert_grads.items():
|
||||
expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
|
||||
for i, bucket_tuple in enumerate(expert_split_buckets):
|
||||
bucket_type, bucket = bucket_tuple
|
||||
if bucket_type == SparseTensor.type():
|
||||
self.sparse_allreduce_no_retain(
|
||||
bucket,
|
||||
groups.get_expert_data_parallel_group(ep_name))
|
||||
else:
|
||||
# Separate between diff groups
|
||||
self.allreduce_no_retain(
|
||||
bucket,
|
||||
dp_group=groups.get_expert_data_parallel_group(ep_name),
|
||||
numel_per_bucket=elements_per_buffer)
|
||||
|
||||
def sparse_allreduce_no_retain(self, bucket, dp_group):
|
||||
allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
|
||||
|
@ -2260,25 +2268,58 @@ class DeepSpeedEngine(Module):
|
|||
sd = self.module.state_dict(destination, prefix, keep_vars)
|
||||
return sd
|
||||
|
||||
def load_moe_state_dict(self, checkpoint_path, tag, state_dict):
|
||||
expp_rank = groups.get_expert_parallel_rank()
|
||||
def load_moe_state_dict(self, checkpoint_path, tag, state_dict, old_moe_load):
|
||||
|
||||
num_local_experts = self.num_experts // self.ep_world_size
|
||||
for local_expert_id in range(num_local_experts):
|
||||
global_expert_id = expp_rank * num_local_experts + local_expert_id
|
||||
expert_state_dict = torch.load(self._get_expert_ckpt_name(
|
||||
checkpoint_path,
|
||||
global_expert_id,
|
||||
tag),
|
||||
map_location=torch.device('cpu'))
|
||||
if old_moe_load:
|
||||
expp_rank = groups.get_expert_data_parallel_rank(
|
||||
groups.get_max_expert_size_name())
|
||||
|
||||
# Updating global -> local expert ids
|
||||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
|
||||
for key in list(expert_state_dict.keys()):
|
||||
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
|
||||
f'{moe_str_prefix}{local_expert_id}')
|
||||
expert_state_dict[local_key] = expert_state_dict.pop(key)
|
||||
state_dict.update(expert_state_dict)
|
||||
num_local_experts = max(
|
||||
self.num_experts) // groups.get_expert_parallel_world_size(
|
||||
groups.get_max_expert_size_name())
|
||||
for local_expert_id in range(num_local_experts):
|
||||
global_expert_id = expp_rank * num_local_experts + local_expert_id
|
||||
expert_state_dict = torch.load(self._get_expert_ckpt_name(
|
||||
checkpoint_path,
|
||||
-1, # -1 means ignore layer_id
|
||||
global_expert_id,
|
||||
tag),
|
||||
map_location=torch.device('cpu'))
|
||||
|
||||
# Updating global -> local expert ids
|
||||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
|
||||
for key in list(expert_state_dict.keys()):
|
||||
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
|
||||
f'{moe_str_prefix}{local_expert_id}')
|
||||
expert_state_dict[local_key] = expert_state_dict.pop(key)
|
||||
state_dict.update(expert_state_dict)
|
||||
|
||||
else:
|
||||
moe_layer_id = 0
|
||||
for n_module, module in self.module.named_modules():
|
||||
if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
|
||||
group_name = module.expert_group_name
|
||||
num_local_experts = module.num_local_experts
|
||||
expp_rank = groups.get_expert_parallel_rank(group_name)
|
||||
# loop all local_experts
|
||||
for local_expert_id in range(num_local_experts):
|
||||
global_expert_id = expp_rank * num_local_experts + local_expert_id
|
||||
expert_state_dict = torch.load(self._get_expert_ckpt_name(
|
||||
checkpoint_path,
|
||||
moe_layer_id,
|
||||
global_expert_id,
|
||||
tag),
|
||||
map_location=torch.device('cpu'))
|
||||
# print(expert_state_dict.keys())
|
||||
# Updating global -> local expert ids
|
||||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
|
||||
for key in list(expert_state_dict.keys()):
|
||||
local_key = key.replace(
|
||||
f'{moe_str_prefix}{global_expert_id}',
|
||||
f'{moe_str_prefix}{local_expert_id}')
|
||||
expert_state_dict[local_key] = expert_state_dict.pop(key)
|
||||
state_dict.update(expert_state_dict)
|
||||
moe_layer_id += 1
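# Key remapping example (illustrative values): with num_local_experts=2 and
# expp_rank=3, global expert 7 lands in local slot 1, so a checkpoint key like
#   '...mlp.deepspeed_moe.experts.deepspeed_experts.7.dense_h_to_4h.weight'
# is renamed to
#   '...mlp.deepspeed_moe.experts.deepspeed_experts.1.dense_h_to_4h.weight'
# before being merged into the module state_dict.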
|
||||
|
||||
def load_module_state_dict(self, state_dict, strict=True):
|
||||
self.module.load_state_dict(state_dict, strict=strict)
|
||||
|
@ -2328,12 +2369,21 @@ class DeepSpeedEngine(Module):
|
|||
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
|
||||
return ckpt_name
|
||||
|
||||
def _get_expert_ckpt_name(self, checkpoints_path, expert_id, tag):
|
||||
def _get_expert_ckpt_name(self, checkpoints_path, layer_id, expert_id, tag):
|
||||
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
|
||||
ckpt_name = os.path.join(
|
||||
checkpoints_path,
|
||||
str(tag),
|
||||
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
|
||||
if layer_id <= -1:
|
||||
# Used to support old checkpoint loading
|
||||
ckpt_name = os.path.join(
|
||||
checkpoints_path,
|
||||
str(tag),
|
||||
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
|
||||
else:
|
||||
# Used to support new checkpoint loading
|
||||
ckpt_name = os.path.join(
|
||||
checkpoints_path,
|
||||
str(tag),
|
||||
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt'
|
||||
)
|
||||
return ckpt_name
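# For example (illustrative values): layer_id=3, expert_id=7, mp_rank=0 and
# tag 'global_step1000' yield
#   <checkpoints_path>/global_step1000/layer_3_expert_7_mp_rank_00_model_states.pt
# while layer_id=-1 falls back to the older naming
#   <checkpoints_path>/global_step1000/expert_7_mp_rank_00_model_states.pt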
|
||||
|
||||
def _get_all_ckpt_names(self, checkpoints_path, tag):
|
||||
|
@ -2433,7 +2483,14 @@ class DeepSpeedEngine(Module):
|
|||
self._curr_ckpt_path = os.path.join(load_dir, tag)
|
||||
|
||||
if self.has_moe_layers:
|
||||
self.load_moe_state_dict(load_dir, tag, state_dict=checkpoint['module'])
|
||||
# print(checkpoint.keys())
|
||||
old_moe_load = False
|
||||
if not isinstance(checkpoint['num_experts'], list):
|
||||
old_moe_load = True
|
||||
self.load_moe_state_dict(load_dir,
|
||||
tag,
|
||||
state_dict=checkpoint['module'],
|
||||
old_moe_load=old_moe_load)
|
||||
|
||||
self.load_module_state_dict(state_dict=checkpoint['module'],
|
||||
strict=load_module_strict)
|
||||
|
@ -2446,7 +2503,8 @@ class DeepSpeedEngine(Module):
|
|||
self.optimizer.refresh_fp32_params()
|
||||
else:
|
||||
if self.has_moe_layers:
|
||||
expp_rank = groups.get_expert_parallel_rank()
|
||||
largest_group_name = groups.get_max_expert_size_name()
|
||||
expp_rank = groups.get_expert_parallel_rank(largest_group_name)
|
||||
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
|
||||
optim_checkpoint = torch.load(optim_load_path,
|
||||
map_location=torch.device('cpu'))
|
||||
|
@ -2700,64 +2758,76 @@ class DeepSpeedEngine(Module):
|
|||
|
||||
return True
|
||||
|
||||
def _get_moe_state_dict(self, full_state_dict, num_local_experts, expp_rank):
|
||||
"""Compute moe and non moe state dict from complete local model state dict
|
||||
key : global_expert_id
|
||||
value : state_dict
|
||||
experts_state_dict =
|
||||
{
|
||||
'0': {
|
||||
'models.seq2seq.encoder.layers.0.experts.moe.experts.experts.0.fc1.weight' <class 'torch.Tensor'>,
|
||||
'models.seq2seq.encoder.layers.1.experts.moe.experts.experts.0.fc1.weight' <class 'torch.Tensor'>,
|
||||
'models.seq2seq.encoder.layers.2.experts.moe.experts.experts.0.fc1.weight' <class 'torch.Tensor'>,
|
||||
...
|
||||
},
|
||||
'1' : {
|
||||
...
|
||||
}
|
||||
}
|
||||
|
||||
returns experts_state_dict, model_state_dict
|
||||
def _get_non_moe_state_dict(self, full_state_dict):
|
||||
"""
|
||||
Get the state dict of the non-moe layers
|
||||
"""
|
||||
experts_state_dict, moe_state_dict = defaultdict(dict), {}
|
||||
for key in list(full_state_dict.keys()):
|
||||
if 'expert' in key and 'moe.gate.wg.weight' not in key:
|
||||
moe_state_dict[key] = full_state_dict.pop(key)
|
||||
non_moe_state_dict = full_state_dict
|
||||
full_state_dict.pop(key)
|
||||
|
||||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
|
||||
for key in list(moe_state_dict.keys()):
|
||||
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
|
||||
|
||||
local_expert_id = None
|
||||
if not m:
|
||||
logger.warn(f'No expert found in key {key}.')
|
||||
else:
|
||||
local_expert_id = m.group(1)
|
||||
|
||||
global_expert_id = expp_rank * \
|
||||
num_local_experts + int(local_expert_id)
|
||||
expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',
|
||||
f'{moe_str_prefix}{global_expert_id}')
|
||||
experts_state_dict[str(global_expert_id)][expert_key] = moe_state_dict.pop(
|
||||
key)
|
||||
|
||||
return experts_state_dict, non_moe_state_dict
|
||||
return full_state_dict
|
||||
|
||||
def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
|
||||
|
||||
save_path = self._get_ckpt_name(save_dir, tag)
|
||||
# A hack to save the checkpointing directory. Pipeline parallelism overrides
|
||||
# module_state_dict() and uses this path to save the model. module_state_dict()
|
||||
# then instead just returns None.
|
||||
|
||||
# Using layer_#_expert_# to save the model's expert state_dict
|
||||
moe_layer_id = 0
|
||||
for n_module, module in self.module.named_modules():
|
||||
if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
|
||||
group_name = module.expert_group_name
|
||||
num_local_experts = module.num_local_experts
|
||||
expp_rank = groups.get_expert_parallel_rank(group_name)
|
||||
exp_dp_rank = groups.get_expert_data_parallel_rank(group_name)
|
||||
# print(expp_rank, exp_dp_rank)
|
||||
if exp_dp_rank != 0:
|
||||
moe_layer_id += 1
|
||||
continue
|
||||
|
||||
# get all moe parameters
|
||||
moe_state_dict = {}
|
||||
for n, p in module.state_dict().items():
|
||||
if 'expert' in n and 'moe.gate.wg.weight' not in n:
|
||||
moe_state_dict[n_module + '.' + n] = p
|
||||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
|
||||
# print(moe_state_dict.keys())
|
||||
# Reorder the moe name rank, so that each checkpoint only has one expert
|
||||
experts_state_dict = defaultdict(dict)
|
||||
for key in list(moe_state_dict.keys()):
|
||||
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
|
||||
|
||||
local_expert_id = None
|
||||
if not m:
|
||||
logger.warn(f'No expert found in key {key}.')
|
||||
else:
|
||||
local_expert_id = m.group(1)
|
||||
|
||||
global_expert_id = expp_rank * \
|
||||
num_local_experts + int(local_expert_id)
|
||||
expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',
|
||||
f'{moe_str_prefix}{global_expert_id}')
|
||||
experts_state_dict[str(
|
||||
global_expert_id)][expert_key] = moe_state_dict.pop(key)
|
||||
|
||||
# let's save the MoE parameters
|
||||
for global_expert_id, expert_state_dict in experts_state_dict.items():
|
||||
# save the moe parameters
|
||||
moe_save_path = self._get_expert_ckpt_name(
|
||||
save_dir,
|
||||
moe_layer_id,
|
||||
global_expert_id,
|
||||
tag)
|
||||
torch.save(expert_state_dict, moe_save_path)
|
||||
moe_layer_id += 1
|
||||
|
||||
self._curr_ckpt_path = os.path.join(save_dir, tag)
|
||||
""""
|
||||
experts_state_dict = {
|
||||
'e_id' : state_dict_for_eid
|
||||
}
|
||||
"""
|
||||
expp_rank = groups.get_expert_parallel_rank()
|
||||
exp_dp_rank = groups.get_expert_data_parallel_rank()
|
||||
|
||||
largest_group_name = groups.get_max_expert_size_name()
|
||||
expp_rank = groups.get_expert_parallel_rank(largest_group_name)
|
||||
exp_dp_rank = groups.get_expert_data_parallel_rank(largest_group_name)
|
||||
|
||||
# In the case of E + D parallelism, only the
|
||||
# first expert parallel group should save the expert weights
|
||||
|
@ -2765,17 +2835,6 @@ class DeepSpeedEngine(Module):
|
|||
if exp_dp_rank != 0:
|
||||
return
|
||||
|
||||
num_local_experts = self.num_experts // self.ep_world_size
|
||||
experts_state_dict, model_state_dict = self._get_moe_state_dict(
|
||||
self.module_state_dict(), num_local_experts, expp_rank)
|
||||
|
||||
# Each rank saves its local experts
|
||||
for global_expert_id, expert_state_dict in experts_state_dict.items():
|
||||
expert_save_dir = self._get_expert_ckpt_name(save_dir, global_expert_id, tag)
|
||||
logger.info(
|
||||
f'Saving model expert {global_expert_id} checkpoint: {expert_save_dir}')
|
||||
torch.save(expert_state_dict, expert_save_dir)
|
||||
|
||||
# Save optimizer states. They are different across each exp parallel rank.
|
||||
optimizer_state = {
|
||||
'optimizer':
|
||||
|
@ -2787,6 +2846,9 @@ class DeepSpeedEngine(Module):
|
|||
tag,
|
||||
expp_rank))
|
||||
|
||||
# get non-moe parameters
|
||||
model_state_dict = self._get_non_moe_state_dict(self.module_state_dict())
|
||||
|
||||
if expp_rank == 0:
|
||||
# TODO: update num experts info,.. in checkpoint
|
||||
state = {
|
||||
|
|
|
@ -141,6 +141,7 @@ class FP16_Optimizer(object):
|
|||
"""
|
||||
Not supporting closure.
|
||||
"""
|
||||
|
||||
# First compute norm for all group so we know if there is overflow
|
||||
grads_groups_flat = []
|
||||
norm_groups = []
|
||||
|
|
|
@ -188,9 +188,11 @@ class CheckOverflow(object):
|
|||
# In this case, we need to do an all_reduce across
|
||||
# the expert_parallel_group, so that if there was
|
||||
# an overflow due to expert weights, we detect it
|
||||
|
||||
# Only need to check groups.get_max_expert_parallel_group()
|
||||
dist.all_reduce(overflow_gpu,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=groups.get_expert_parallel_group())
|
||||
group=groups.get_max_expert_parallel_group())
|
||||
if self.mpu is not None:
|
||||
torch.distributed.all_reduce(overflow_gpu,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
|
@ -241,7 +243,7 @@ class CheckOverflow(object):
|
|||
# overflows, we detect it here
|
||||
dist.all_reduce(overflow_gpu,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=groups.get_expert_parallel_group())
|
||||
group=groups.get_max_expert_parallel_group())
|
||||
if self.zero_reduce_scatter:
|
||||
torch.distributed.all_reduce(overflow_gpu,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
|
|
|
@ -116,7 +116,6 @@ class WeightQuantization(object):
|
|||
return sd, all_scales
|
||||
|
||||
def model_quantize(self, model, quantize_policy, quantize_bits, groups):
|
||||
|
||||
all_scales = []
|
||||
|
||||
def quantize_fn(layer, policy_cls):
|
||||
|
|
|
@ -181,7 +181,6 @@ class DeepSpeedZeroOptimizer(object):
|
|||
self.contiguous_gradients = contiguous_gradients or cpu_offload
|
||||
|
||||
self.has_moe_layers = has_moe_layers
|
||||
|
||||
if self.has_moe_layers:
|
||||
self._configure_moe_settings()
|
||||
self._global_grad_norm = 0.
|
||||
|
@ -254,6 +253,7 @@ class DeepSpeedZeroOptimizer(object):
|
|||
self.round_robin_bit16_groups = []
|
||||
self.round_robin_bit16_indices = []
|
||||
|
||||
# Use different parallel to do all_to_all_reduce related things
|
||||
# padding on each partition for alignment purposes
|
||||
self.groups_padding = []
|
||||
# loop to deal with groups
|
||||
|
@ -507,9 +507,10 @@ class DeepSpeedZeroOptimizer(object):
|
|||
for i, group in enumerate(self.optimizer.param_groups):
|
||||
if self.is_moe_group(group):
|
||||
assert all([is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params"
|
||||
self.real_dp_process_group[i] = self.expert_dp_process_group
|
||||
self.real_dp_process_group[i] = self.expert_dp_process_group[
|
||||
group['name']]
|
||||
self.partition_count[i] = dist.get_world_size(
|
||||
group=self.expert_dp_process_group)
|
||||
group=self.expert_dp_process_group[group['name']])
|
||||
self.is_moe_param_group.append(True)
|
||||
else:
|
||||
self.is_moe_param_group.append(False)
|
||||
|
@ -901,12 +902,15 @@ class DeepSpeedZeroOptimizer(object):
|
|||
process_group = self.dp_process_group
|
||||
#Averages gradients at parameter level if ipg has a moe param
|
||||
#Otherwise averaging is done at the entire buffer level at the end of the loop
|
||||
# MoE param have different groups
|
||||
if self.ipg_bucket_has_moe_params:
|
||||
process_group = self.expert_dp_process_group if is_moe_param(
|
||||
param) else self.dp_process_group
|
||||
process_group = self.expert_dp_process_group[
|
||||
param.group_name] if is_moe_param(
|
||||
param) else self.dp_process_group
|
||||
param.grad.data.div_(dist.get_world_size(group=process_group))
|
||||
|
||||
partition_ids = self.param_to_partition_ids[i][param_id]
|
||||
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
|
||||
partition_size = self.partition_size[i]
|
||||
# Get all partition ids + their offsets
|
||||
partition_ids_w_offsets = []
|
||||
|
@ -1790,11 +1794,11 @@ class DeepSpeedZeroOptimizer(object):
|
|||
for i, norm in enumerate(norm_groups):
|
||||
if self.is_moe_param_group[i]:
|
||||
scaled_norm = norm * 1.0 / float(
|
||||
dist.get_world_size(group=self.ep_process_group))
|
||||
dist.get_world_size(group=self.real_dp_process_group[i]))
|
||||
scaled_norm_tensor = torch.tensor(scaled_norm,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
dist.all_reduce(scaled_norm_tensor, group=self.ep_process_group)
|
||||
dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
|
||||
norm_groups[i] = scaled_norm_tensor.item()
|
||||
|
||||
def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
|
||||
|
@ -2025,7 +2029,8 @@ class DeepSpeedZeroOptimizer(object):
|
|||
sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict
|
||||
]
|
||||
if self.is_moe_group(self.optimizer.param_groups[i]):
|
||||
ranks = self.get_ep_ranks()
|
||||
ranks = self.get_ep_ranks(
|
||||
group_name=self.optimizer.param_groups[i]['name'])
|
||||
merged_partitions = [merged_partitions[i] for i in ranks]
|
||||
flat_merged_partitions = self.flatten_dense_tensors_aligned(
|
||||
merged_partitions,
|
||||
|
@ -2075,11 +2080,11 @@ class DeepSpeedZeroOptimizer(object):
|
|||
else:
|
||||
self.optimizer.state[p][key] = saved
|
||||
|
||||
def get_ep_ranks(self, rank=0):
|
||||
def get_ep_ranks(self, rank=0, group_name=None):
|
||||
from deepspeed.utils import groups
|
||||
expert_parallel_size_ = groups.get_expert_parallel_world_size()
|
||||
expert_parallel_size_ = groups.get_expert_parallel_world_size(group_name)
|
||||
world_size = groups.get_data_parallel_world_size()
|
||||
rank = groups.get_expert_parallel_rank()
|
||||
rank = groups.get_expert_parallel_rank(group_name)
|
||||
ranks = range(rank, world_size, expert_parallel_size_)
|
||||
return list(ranks)
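# Worked example (illustrative): with a data-parallel world size of 8, an
# expert parallel size of 2 and expert-parallel rank 1, this returns
# range(1, 8, 2) -> [1, 3, 5, 7].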
|
||||
|
||||
|
@ -2096,7 +2101,8 @@ class DeepSpeedZeroOptimizer(object):
|
|||
]
|
||||
|
||||
if self.is_moe_group(self.optimizer.param_groups[i]):
|
||||
ranks = self.get_ep_ranks()
|
||||
ranks = self.get_ep_ranks(
|
||||
group_name=self.optimizer.param_groups[i]['name'])
|
||||
all_partition_group_states = [
|
||||
all_partition_group_states[i] for i in ranks
|
||||
]
|
||||
|
@ -2151,7 +2157,7 @@ class DeepSpeedZeroOptimizer(object):
|
|||
self.overflow = current_rank_sd['overflow']
|
||||
|
||||
ckpt_version = current_rank_sd.get("ds_version", False)
|
||||
assert ckpt_version, f"Empty ds_version! {error_str}"
|
||||
assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
|
||||
ckpt_version = pkg_version.parse(ckpt_version)
|
||||
|
||||
# zero stage 1 mode
|
||||
|
|
|
@ -55,11 +55,14 @@ from deepspeed.utils import logger, log_dist
|
|||
# Model parallel group that the current rank belongs to.
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
# Expert parallel group that the current rank belongs to.
|
||||
_EXPERT_PARALLEL_GROUP = None
|
||||
_EXPERT_PARALLEL_GROUP = None # e.g. {"ep_size_32": parallel_group}
|
||||
# Expert data parallel group that the current rank belongs to.
|
||||
_EXPERT_DATA_PARALLEL_GROUP = None
|
||||
# Data parallel group that the current rank belongs to.
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
# Max EP SIZE
|
||||
_MAX_EP_SIZE = None
|
||||
_MAX_EP_SIZE_NAME = None
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
|
@ -68,7 +71,7 @@ def ensure_divisibility(numerator, denominator):
|
|||
numerator, denominator)
|
||||
|
||||
|
||||
def initialize(ep_size=1, mpu=None):
|
||||
def initialize(ep_size=1, mpu=None, num_ep_list=None):
|
||||
"""
|
||||
Process groups initialization supporting expert (E), data (D), and model (M) parallelism. DeepSpeed considers
|
||||
the following scenarios w.r.t. process group creation.
|
||||
|
@@ -98,18 +101,32 @@ def initialize(ep_size=1, mpu=None):
engine = deepspeed.initialize(model, mpu=mpu) # passing mpu is optional in this case

Arguments:
ep_size (int, optional): default=1, expert parallel size
ep_size (int, optional): default=1, maximum expert parallel size; it should divide the world size and must divide or be divisible by each element in num_ep_list.
mpu (module, optional): default=None, model parallel unit (e.g., from Megatron)
that describes model/data parallel ranks.
num_ep_list (list, optional): default=None, list of the number of experts used in each MoE layer.

"""

if num_ep_list is None:
num_ep_list = [ep_size]

assert max(num_ep_list) >= ep_size, f"ep_size={ep_size} is larger than the largest value in num_ep_list={max(num_ep_list)}, you should reduce the expert parallel size"

num_ep_list = list(set(num_ep_list))  # remove duplicates
num_ep_list.sort()  # sort in ascending order
for num_ep in num_ep_list:
assert num_ep > 0, 'num_ep must be positive'
assert num_ep % ep_size == 0 or ep_size % num_ep == 0, 'num_ep must divide or be divisible by ep_size'

if mpu is not None:
log_dist(message="initializing deepspeed groups using mpu", ranks=[0])
initialize_model_and_expert_parallel(ep_size, mpu)
initialize_model_and_expert_parallel(ep_size, mpu, num_ep_list)
else:
log_dist(message="initializing deepspeed groups", ranks=[0])
initialize_model_parallel(1)
initialize_expert_parallel(ep_size)
initialize_expert_parallel(ep_size, num_ep_list)
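To make the extended calling convention concrete, here is a hedged usage sketch of initialize for a PR-MoE style model whose layers use different expert counts. The sizes are invented for illustration, and the call assumes torch.distributed has already been initialized with a compatible world size:

from deepspeed.utils import groups

# Hypothetical PR-MoE setup: some layers use 32 experts, others 64.
# ep_size caps the expert parallel degree; per the asserts above, every entry
# in num_ep_list must divide ep_size or be divisible by it.
groups.initialize(ep_size=8, num_ep_list=[32, 64])

# Groups are afterwards looked up by their "ep_size_<n>" name, e.g. the largest one:
ep_group = groups.get_expert_parallel_group(groups.get_max_expert_size_name())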
def initialize_model_parallel(model_parallel_size_):

@@ -163,7 +180,7 @@ def initialize_model_parallel(model_parallel_size_):
_MODEL_PARALLEL_GROUP = group


def initialize_expert_parallel(expert_parallel_size_):
def initialize_expert_parallel(expert_parallel_size_, num_ep_list_=None):
"""
Initialize expert plus data parallel groups.

@@ -176,9 +193,18 @@ def initialize_expert_parallel(expert_parallel_size_):
"""
assert torch.distributed.is_initialized()

global _MAX_EP_SIZE
global _MAX_EP_SIZE_NAME
_MAX_EP_SIZE = expert_parallel_size_
_MAX_EP_SIZE_NAME = f"ep_size_{expert_parallel_size_}"

if num_ep_list_ is None:
num_ep_list_ = [expert_parallel_size_]

log_dist(
'initializing deepspeed expert parallel group with size {}'.format(
expert_parallel_size_),
'initializing deepspeed expert parallel group with max size {} for expert count list {}'
.format(expert_parallel_size_,
num_ep_list_),
[0])
world_size = get_data_parallel_world_size()
rank = get_data_parallel_rank()
@@ -190,33 +216,58 @@ def initialize_expert_parallel(expert_parallel_size_):
global _EXPERT_DATA_PARALLEL_GROUP
assert _EXPERT_DATA_PARALLEL_GROUP is None, \
'expert data parallel group is already initialized'
for i in range(expert_parallel_size_):
ranks = range(i, world_size, expert_parallel_size_)
group = torch.distributed.new_group(ranks)

# TODO: remove
log_dist(
f'creating expert data parallel process group with ranks: {list(ranks)}',
[0])
if i == (rank % expert_parallel_size_):
_EXPERT_DATA_PARALLEL_GROUP = group
_EXPERT_DATA_PARALLEL_GROUP = {}

for num_ep in num_ep_list_:
# Build the expert data parallel groups for each num_ep.
# There are two cases:
# 1. num_ep >= expert_parallel_size_: such layers share the groups built once for expert_parallel_size_
# 2. num_ep < expert_parallel_size_: a new, smaller set of groups keyed by num_ep has to be created
if num_ep >= expert_parallel_size_:
if f"ep_size_{expert_parallel_size_}" not in _EXPERT_DATA_PARALLEL_GROUP:
for i in range(expert_parallel_size_):
# generate all groups
ranks = range(i, world_size, expert_parallel_size_)
group = torch.distributed.new_group(ranks)
if i == (rank % expert_parallel_size_):
# get the correct group
_EXPERT_DATA_PARALLEL_GROUP[
f"ep_size_{expert_parallel_size_}"] = group
else:
for i in range(num_ep):
ranks = range(i, world_size, num_ep)
group = torch.distributed.new_group(ranks)
if i == (rank % num_ep):
_EXPERT_DATA_PARALLEL_GROUP[f"ep_size_{num_ep}"] = group
# Build the expert parallel groups.
global _EXPERT_PARALLEL_GROUP
assert _EXPERT_PARALLEL_GROUP is None, \
'expert parallel group is already initialized'
for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
group = torch.distributed.new_group(ranks)

# TODO: remove
log_dist(f'creating expert parallel process group with ranks: {list(ranks)}',
[0])
if i == (rank // expert_parallel_size_):
_EXPERT_PARALLEL_GROUP = group
_EXPERT_PARALLEL_GROUP = {}

for num_ep in num_ep_list_:
# Similarly to the expert data parallel groups above, there are two cases to handle
if num_ep >= expert_parallel_size_:
if f"ep_size_{expert_parallel_size_}" not in _EXPERT_PARALLEL_GROUP:
for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_,
(i + 1) * expert_parallel_size_)
group = torch.distributed.new_group(ranks)
if i == (rank // expert_parallel_size_):
_EXPERT_PARALLEL_GROUP[
f"ep_size_{expert_parallel_size_}"] = group
else:
for i in range(world_size // num_ep):
ranks = range(i * num_ep, (i + 1) * num_ep)
group = torch.distributed.new_group(ranks)
if i == (rank // num_ep):
_EXPERT_PARALLEL_GROUP[f"ep_size_{num_ep}"] = group
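As a concrete illustration of the layout these loops produce, consider a hypothetical run with world_size=8, expert_parallel_size_=4 and num_ep_list_=[2, 4] (numbers chosen only for this sketch): expert parallel groups are contiguous rank blocks, while expert data parallel groups stride across them.

# Sketch only: reproduces the rank arithmetic of initialize_expert_parallel
# for world_size=8, expert_parallel_size_=4, num_ep_list_=[2, 4].
world_size = 8

def expert_parallel_groups(size):
    # contiguous blocks of `size` ranks
    return [list(range(i * size, (i + 1) * size)) for i in range(world_size // size)]

def expert_data_parallel_groups(size):
    # every `size`-th rank, starting at offset i
    return [list(range(i, world_size, size)) for i in range(size)]

print(expert_parallel_groups(4))       # [[0, 1, 2, 3], [4, 5, 6, 7]]     -> key "ep_size_4"
print(expert_data_parallel_groups(4))  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> key "ep_size_4"
print(expert_parallel_groups(2))       # [[0, 1], [2, 3], [4, 5], [6, 7]] -> key "ep_size_2"
print(expert_data_parallel_groups(2))  # [[0, 2, 4, 6], [1, 3, 5, 7]]     -> key "ep_size_2"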
def initialize_model_and_expert_parallel(expert_parallel_size_, mpu):
def initialize_model_and_expert_parallel(expert_parallel_size_, mpu, num_ep_list_=None):
"""
Initialize Expert groups based on MPU groups.

@@ -233,13 +284,21 @@ def initialize_model_and_expert_parallel(expert_parallel_size_, mpu):
assert mpu.model_parallel_is_initialized(), "model parallel group is not initialized"
model_parallel_size_ = mpu.get_model_parallel_world_size()

global _MAX_EP_SIZE
global _MAX_EP_SIZE_NAME
_MAX_EP_SIZE = expert_parallel_size_
_MAX_EP_SIZE_NAME = f"ep_size_{expert_parallel_size_}"
if num_ep_list_ is None:
num_ep_list_ = [expert_parallel_size_]
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()

log_dist(
f"Initializing deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, and data parallel size {world_size}",
f"Initializing deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
[0])

global _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP

@@ -259,31 +318,54 @@ def initialize_model_and_expert_parallel(expert_parallel_size_, mpu):
assert _EXPERT_PARALLEL_GROUP is None, \
'expert parallel group is already initialized'
for j in range(model_parallel_size_):
for i in range(expert_parallel_size_):
ranks = range(i * model_parallel_size_ + j,
world_size,
expert_parallel_size_ * model_parallel_size_)
group = torch.distributed.new_group(ranks)
_EXPERT_DATA_PARALLEL_GROUP = {}
_EXPERT_PARALLEL_GROUP = {}

# TODO: remove
log_dist(
f'creating expert data parallel process group with ranks: {list(ranks)}',
[0])
if rank in list(ranks):
_EXPERT_DATA_PARALLEL_GROUP = group
for num_ep in num_ep_list_:
for j in range(model_parallel_size_):
# For the expert data parallel groups:
# similarly to initialize_expert_parallel, there are two cases to handle
if num_ep >= expert_parallel_size_:
# TODO: refactor to check this condition in the outer for-loop
if True:  # f"ep_size_{expert_parallel_size_}" not in _EXPERT_DATA_PARALLEL_GROUP:
for i in range(expert_parallel_size_):
ranks = range(i * model_parallel_size_ + j,
world_size,
expert_parallel_size_ * model_parallel_size_)
group = torch.distributed.new_group(ranks)
if rank in list(ranks):
_EXPERT_DATA_PARALLEL_GROUP[
f"ep_size_{expert_parallel_size_}"] = group
else:
for i in range(num_ep):
ranks = range(i * model_parallel_size_ + j,
world_size,
num_ep * model_parallel_size_)
group = torch.distributed.new_group(ranks)
if rank in list(ranks):
_EXPERT_DATA_PARALLEL_GROUP[f"ep_size_{num_ep}"] = group

for i in range(dp_world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_ * model_parallel_size_ + j,
(i + 1) * expert_parallel_size_ * model_parallel_size_,
model_parallel_size_)
group = torch.distributed.new_group(ranks)

# TODO: remove
log_dist(f'creating expert parallel process group with ranks: {list(ranks)}',
[0])
if rank in list(ranks):
_EXPERT_PARALLEL_GROUP = group
# For the expert parallel groups
if num_ep >= expert_parallel_size_:
# TODO: refactor to check this condition in the outer for-loop
if True:  # f"ep_size_{expert_parallel_size_}" not in _EXPERT_PARALLEL_GROUP:
for i in range(dp_world_size // expert_parallel_size_):
ranks = range(
i * expert_parallel_size_ * model_parallel_size_ + j,
(i + 1) * expert_parallel_size_ * model_parallel_size_,
model_parallel_size_)
group = torch.distributed.new_group(ranks)
if rank in list(ranks):
_EXPERT_PARALLEL_GROUP[
f"ep_size_{expert_parallel_size_}"] = group
else:
for i in range(dp_world_size // num_ep):
ranks = range(i * num_ep * model_parallel_size_ + j,
(i + 1) * num_ep * model_parallel_size_,
model_parallel_size_)
group = torch.distributed.new_group(ranks)
if rank in list(ranks):
_EXPERT_PARALLEL_GROUP[f"ep_size_{num_ep}"] = group
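To see how model parallelism interleaves with these groups, here is a small worked sketch with hypothetical sizes model_parallel_size_=2, expert_parallel_size_=2 and world_size=8 (so dp_world_size=4); it only mirrors the rank arithmetic of the loops above:

# Sketch only: rank layout produced by initialize_model_and_expert_parallel
# for model_parallel_size_=2, expert_parallel_size_=2, world_size=8.
mp, ep, world_size = 2, 2, 8
dp_world_size = world_size // mp

for j in range(mp):
    # expert data parallel groups for model parallel slice j (stride ep * mp)
    edp = [list(range(i * mp + j, world_size, ep * mp)) for i in range(ep)]
    # expert parallel groups for slice j (stride mp inside each contiguous block)
    epg = [list(range(i * ep * mp + j, (i + 1) * ep * mp, mp)) for i in range(dp_world_size // ep)]
    print(j, edp, epg)
# 0 [[0, 4], [2, 6]] [[0, 2], [4, 6]]
# 1 [[1, 5], [3, 7]] [[1, 3], [5, 7]]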
def is_initialized():

@@ -314,15 +396,48 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP


def get_expert_parallel_group():
def get_max_expert_parallel_group():
"""Get the expert parallel group with the maximum size."""
return get_expert_parallel_group(get_max_expert_size_name())


def get_max_expert_size_name():
"""Get the name of the largest expert parallel group across all groups."""
assert _MAX_EP_SIZE_NAME is not None, \
'max expert parallel size is not initialized'
return _MAX_EP_SIZE_NAME


def get_max_expert_size():
"""Get the largest expert parallel group size across all groups."""
assert _MAX_EP_SIZE is not None, \
'max expert parallel size is not initialized'
return _MAX_EP_SIZE


def get_expert_parallel_group(group_name):
"""Get the expert parallel group the caller rank belongs to."""
assert _EXPERT_PARALLEL_GROUP is not None, \
'expert parallel group is not initialized'
return _EXPERT_PARALLEL_GROUP[group_name]


def get_expert_parallel_group_dict():
"""Get the expert parallel group dict."""
assert _EXPERT_PARALLEL_GROUP is not None, \
'expert parallel group is not initialized'
return _EXPERT_PARALLEL_GROUP


def get_expert_data_parallel_group():
def get_expert_data_parallel_group(group_name):
"""Get the expert data parallel group the caller rank belongs to."""
assert _EXPERT_DATA_PARALLEL_GROUP is not None, \
'expert data parallel group is not initialized'
return _EXPERT_DATA_PARALLEL_GROUP[group_name]


def get_expert_data_parallel_group_dict():
"""Get the expert data parallel group dict."""
assert _EXPERT_DATA_PARALLEL_GROUP is not None, \
'expert data parallel group is not initialized'
return _EXPERT_DATA_PARALLEL_GROUP
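Taken together, the group globals are now dictionaries keyed by an "ep_size_<n>" name instead of single process-group handles. A hedged usage sketch (the "ep_size_4" key and the all-reduce are illustrative; real names come from the initializers above):

import torch
import torch.distributed as dist
from deepspeed.utils import groups

# Assumes groups.initialize(...) has already run on every rank.
ep_group = groups.get_expert_parallel_group("ep_size_4")        # illustrative key
edp_group = groups.get_expert_data_parallel_group("ep_size_4")  # illustrative key

# Example collective over the expert parallel group (assumes a CUDA/NCCL setup).
t = torch.ones(1, device="cuda")
dist.all_reduce(t, group=ep_group)

# Layers using the largest expert count can simply take the max group.
max_group = groups.get_max_expert_parallel_group()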
@@ -340,14 +455,15 @@ def get_model_parallel_world_size():
return torch.distributed.get_world_size(group=get_model_parallel_group())


def get_expert_parallel_world_size():
def get_expert_parallel_world_size(group_name):
"""Return world size for the expert parallel group."""
return torch.distributed.get_world_size(group=get_expert_parallel_group())
return torch.distributed.get_world_size(group=get_expert_parallel_group(group_name))


def get_expert_data_parallel_world_size():
def get_expert_data_parallel_world_size(group_name):
"""Return world size for the expert data parallel group."""
return torch.distributed.get_world_size(group=get_expert_data_parallel_group())
return torch.distributed.get_world_size(
group=get_expert_data_parallel_group(group_name))


def get_model_parallel_rank():
@@ -355,9 +471,9 @@ def get_model_parallel_rank():
return torch.distributed.get_rank(group=get_model_parallel_group())


def get_expert_parallel_rank():
def get_expert_parallel_rank(group_name):
"""Return my rank for the expert parallel group."""
return torch.distributed.get_rank(group=get_expert_parallel_group())
return torch.distributed.get_rank(group=get_expert_parallel_group(group_name))


def get_model_parallel_src_rank():
@@ -368,17 +484,17 @@ def get_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size


def get_expert_parallel_src_rank():
def get_expert_parallel_src_rank(group_name):
"""Calculate the global rank corresponding to a local rank zero
in the expert parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_expert_parallel_world_size()
local_world_size = get_expert_parallel_world_size(group_name)
return (global_rank // local_world_size) * local_world_size


def get_expert_data_parallel_rank():
def get_expert_data_parallel_rank(group_name):
"""Return my rank for the expert data parallel group."""
return torch.distributed.get_rank(group=get_expert_data_parallel_group())
return torch.distributed.get_rank(group=get_expert_data_parallel_group(group_name))


def get_data_parallel_world_size():
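The src-rank arithmetic maps any global rank to the first rank of its expert parallel block; this holds for the contiguous layout built by the non-MPU path above. A tiny worked sketch with hypothetical numbers:

# Sketch only: the floor-divide/multiply step used by get_expert_parallel_src_rank,
# assuming a contiguous expert parallel layout with local_world_size=4.
local_world_size = 4
for global_rank in range(8):
    src = (global_rank // local_world_size) * local_world_size
    print(global_rank, "->", src)  # ranks 0-3 -> 0, ranks 4-7 -> 4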
@@ -25,8 +25,8 @@ __op_builders__ = [
TransformerBuilder(),
StochasticTransformerBuilder(),
AsyncIOBuilder(),
InferenceBuilder(),
UtilsBuilder(),
QuantizerBuilder()
QuantizerBuilder(),
InferenceBuilder()
]
ALL_OPS = {op.name: op for op in __op_builders__}
@@ -465,7 +465,6 @@ class CUDAOpBuilder(OpBuilder):
- `cross_compile_archs` uses ; separator.

"""

ccs = []
if self.jit_mode:
# Compile for underlying architectures since we know those at runtime

@@ -537,10 +536,12 @@ class CUDAOpBuilder(OpBuilder):
return ['-O3', '-std=c++14', '-g', '-Wno-reorder']

def nvcc_args(self):
cuda_major, _ = installed_cuda_version()
args = [
'-O3',
'--use_fast_math',
'-std=c++17' if sys.platform == "win32" else '-std=c++14',
'-std=c++17'
if sys.platform == "win32" and cuda_major > 10 else '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__'
setup.py
@@ -20,6 +20,7 @@ import shutil
import subprocess
import warnings
from setuptools import setup, find_packages
from setuptools.command import egg_info
import time

torch_available = True

@@ -180,6 +181,7 @@ if sys.platform == "win32":
# It needs Administrator privilege to create symlinks on Windows.
create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc')
create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder')
egg_info.manifest_maker.template = 'MANIFEST_win.in'

# Parse the DeepSpeed version string from version.txt
version_str = open('version.txt', 'r').read().strip()

@@ -239,7 +241,9 @@ setup(name='deepspeed',
install_requires=install_requires,
extras_require=extras_require,
packages=find_packages(exclude=["docker",
"third_party"]),
"third_party",
"csrc",
"op_builder"]),
include_package_data=True,
scripts=[
'bin/deepspeed',
@@ -7,6 +7,7 @@ from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.utils import groups
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

from deepspeed.runtime.pipe.topology import *

@@ -991,29 +992,10 @@ def test_checkpoint_moe_and_zero(tmpdir, ep_size, load_optim_states):
hidden_dim = 16
args = args_from_dict(tmpdir, config_dict)

def create_moe_param_groups(model):
from deepspeed.moe.utils import is_moe_param

params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'}
moe_params_with_weight_decay = {
'params': [],
'moe': True,
'name': 'weight_decay_moe_params'
}

for module_ in model.modules():
moe_params_with_weight_decay['params'].extend([
p for n,
p in list(module_._parameters.items())
if p is not None and is_moe_param(p)
])
params_with_weight_decay['params'].extend([
p for n,
p in list(module_._parameters.items())
if p is not None and not is_moe_param(p)
])

return params_with_weight_decay, moe_params_with_weight_decay
def create_param_groups(model):
# param group must have a random unique name (for now)
# TODO: clean-up this requirement, the unique name should not be required here
return {'params': model.parameters(), 'name': 'random-unique-name'}

@distributed_test(world_size=[4])
def _helper(args):
@@ -1022,7 +1004,10 @@ def test_checkpoint_moe_and_zero(tmpdir, ep_size, load_optim_states):
SimpleMoEModel(hidden_dim=hidden_dim,
num_experts=ep_size) for _ in range(2)
]
params = [create_moe_param_groups(model) for model in models]
params = [
split_params_into_different_moe_groups_for_optimizer(
create_param_groups(model)) for model in models
]
optimizers = [torch.optim.AdamW(params=param) for param in params]
checkpoint_correctness_verification(args,
models=models,
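On the optimizer side, the test now builds one named param group per model and lets split_params_into_different_moe_groups_for_optimizer pull the expert parameters into their own groups; that is what allows the ZeRO optimizer to key get_ep_ranks on param_group['name']. A hedged sketch of the pattern (the stand-in model and the shape of the returned groups are illustrative, not a spec of the helper's output):

import torch
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

model = torch.nn.Linear(4, 4)  # stand-in; a real MoE model would contain expert layers

# One base group with a unique name, as in create_param_groups above.
base_group = {'params': model.parameters(), 'name': 'random-unique-name'}

# The helper returns param groups in which MoE (expert) parameters have been
# separated out into their own named groups.
param_groups = split_params_into_different_moe_groups_for_optimizer(base_group)

optimizer = torch.optim.AdamW(params=param_groups)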
@@ -52,7 +52,7 @@ def test_moe(tmpdir, ep_size):
#dist_init_required=False -- parameterize to True/False?

assert dist.get_world_size() == groups.get_data_parallel_world_size(), "incorrect data parallel world size"
assert ep_size == groups.get_expert_parallel_world_size(), "incorrect expert parallel world size"
assert ep_size == groups.get_expert_parallel_world_size(groups.get_max_expert_size_name()), "incorrect expert parallel world size"

data_loader = sequence_dataloader(model=model,
total_samples=50,
@@ -1 +1 @@
0.5.10
0.6.0