Mirror of https://github.com/microsoft/DeepSpeed.git
GPT-J inference support (#1670)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 7e857aab9a
Commit: 289c3f9ba4
@@ -0,0 +1,129 @@
#include "custom_cuda_layers.h"

#include <cuda_profiler_api.h>

namespace cg = cooperative_groups;

__global__ void apply_rotary_pos_emb(float* mixed_query,
                                     float* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;

            lane += WARP_SIZE;
        }
    }
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
                                     __half* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
#if __CUDA_ARCH__ >= 700

    unsigned head_id = blockIdx.x * blockDim.y + threadIdx.y;
    if (head_id < total_count) {
        unsigned offset = head_id * head_size + threadIdx.x;
        unsigned tid = threadIdx.x;
        unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

        cg::thread_block b = cg::this_thread_block();
        cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

        while (tid < rotary_dim) {
            // tid indexes the element within the head, playing the role of `lane` in the fp32 kernel
            float inv_freq = (float)((tid / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset];
            float k = (float)key_layer[offset];
            float rotary_sign = (tid % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset] = (__half)q;
            key_layer[offset] = (__half)k;

            tid += blockDim.x;
            offset += blockDim.x;
        }
    }
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 cudaStream_t stream)
{
    int total_count = batch * num_heads * seq_len;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);  // (batch_size);

    apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
        mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}

template void launch_apply_rotary_pos_emb<float>(float*,
                                                 float*,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
                                                  __half*,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  cudaStream_t);
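For reference, the two kernels above apply the interleaved (even/odd pair) rotary position embedding used by GPT-J in place to the first rotary_dim channels of every attention head. A rough PyTorch sketch of the same update follows; it is not part of the commit, and the function name and the [batch, seq, heads, head_size] layout are illustrative assumptions.

import torch

def rotary_pos_emb_ref(x, rotary_dim, seq_offset=0, base=10000.0):
    # x: [batch, seq, heads, head_size]; only the first rotary_dim channels
    # of each head are rotated, as in the kernels above.
    b, s, h, d = x.shape
    pos = torch.arange(seq_offset, seq_offset + s, dtype=torch.float32)
    freq_idx = torch.arange(0, rotary_dim, 2, dtype=torch.float32)
    theta = pos[:, None] / base ** (freq_idx[None, :] / rotary_dim)  # [seq, rotary_dim/2]
    cos = theta.cos()[None, :, None, :]
    sin = theta.sin()[None, :, None, :]

    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_even, x_odd = x_rot[..., 0::2], x_rot[..., 1::2]
    out_even = x_even * cos - x_odd * sin  # rotation applied to each (even, odd) pair
    out_odd = x_odd * cos + x_even * sin
    out_rot = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
    return torch.cat((out_rot, x_pass), dim=-1)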
@@ -264,3 +264,107 @@ template void launch_bias_residual<__half>(__half*,
                                            int,
                                            int,
                                            cudaStream_t);

__global__ void gptj_residual_add(float* input,
                                  float* output,
                                  float* attn,
                                  float* bias,
                                  int total_count,
                                  int intermediate_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += (out.x + res_vec.x + bias_data.x);
        data.y += (out.y + res_vec.y + bias_data.y);
        data.z += (out.z + res_vec.z + bias_data.z);
        data.w += (out.w + res_vec.w + bias_data.w);

        output_cast[offset] = data;
    }
}

__global__ void gptj_residual_add(__half* input,
                                  __half* output,
                                  __half* attn,
                                  __half* bias,
                                  int total_count,
                                  int intermediate_size)
{
#if __CUDA_ARCH__ >= 700

    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);

    float2* bias_cast = reinterpret_cast<float2*>(bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec = attn_cast[offset];

        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_out = __half22float2(out_half[0]);
        float2 high_out = __half22float2(out_half[1]);

        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += (low_out.x + low_res.x + low_bias.x);
        low_data.y += (low_out.y + low_res.y + low_bias.y);
        high_data.x += (high_out.x + high_res.x + high_bias.x);
        high_data.y += (high_out.y + high_res.y + high_bias.y);

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        output_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              int hidden_dim,
                              int batch,
                              cudaStream_t stream)
{
    int total_count = batch * hidden_dim / 4;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);

    gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
        input, output, attn, bias, total_count, hidden_dim / 4);
}

template void
launch_gptj_residual_add<float>(float*, float*, float*, float*, int, int, cudaStream_t);
template void
launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t);
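The residual kernels above implement GPT-J's parallel residual: the buffer passed as `output` is overwritten with input + output + attn + bias, with the bias broadcast over the hidden dimension. A minimal PyTorch-style sketch of the same arithmetic (not part of the commit; names are illustrative):

def gptj_residual_add_ref(mlp_out, layer_input, attn_out, mlp_bias):
    # All arguments are tensors; mlp_bias has shape [hidden] and broadcasts
    # over [batch, seq, hidden]. The kernels write this sum back into the
    # MLP-output buffer.
    return layer_input + mlp_out + attn_out + mlp_bias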
@@ -204,16 +204,19 @@ at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& b
}

template <typename T>
void qkv_unfused_cublas(at::Tensor& output,
                        at::Tensor& input,
                        at::Tensor& weight,
                        at::Tensor& bias,
                        at::Tensor& gamma,
                        at::Tensor& beta,
                        const float epsilon,
                        bool add_bias)
at::Tensor qkv_unfused_cublas(at::Tensor& output,
                              at::Tensor& input,
                              at::Tensor& weight,
                              at::Tensor& bias,
                              at::Tensor& gamma,
                              at::Tensor& beta,
                              const float epsilon,
                              bool add_bias)
{
    auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);

    // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    int bsz = input.size(0) * input.size(1);

@@ -236,16 +239,17 @@ void qkv_unfused_cublas(at::Tensor& output,
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());
    return inp_norm;
}

template <typename T>
at::Tensor ds_qkv_gemm(at::Tensor& input,
                       at::Tensor& weight,
                       at::Tensor& bias,
                       at::Tensor& gamma,
                       at::Tensor& beta,
                       const float epsilon,
                       bool add_bias)
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                                    at::Tensor& weight,
                                    at::Tensor& bias,
                                    at::Tensor& gamma,
                                    at::Tensor& beta,
                                    const float epsilon,
                                    bool add_bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()

@@ -256,9 +260,10 @@ at::Tensor ds_qkv_gemm(at::Tensor& input,

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
    auto inp_norm =
        qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);

    return output;
    return {output, inp_norm};
}

template <typename T>
@@ -592,6 +597,126 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
    return {output, residual_add};
}

template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
                           at::Tensor& weight,
                           at::Tensor& bias,
                           at::Tensor& weight_out)
{
    // cudaStreamWaitEvent(
    //     Context::Instance().GetCurrentStream(true), Context::Instance().GetCompEvent(1), 0);
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto intermediate =
        at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   weight.size(1),
                   bsz,
                   input.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)input_cont.data_ptr(),
                   (T*)intermediate.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)intermediate.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());

    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   weight_out.size(1),
                   bsz,
                   intermediate.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight_out.data_ptr(),
                   (T*)intermediate.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    // cudaEventRecord(Context::Instance().GetCompEvent(2),
    //                 Context::Instance().GetCurrentStream(true));
    return output;
}
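fused_gemm_gelu above chains two cuBLAS GEMMs with a bias + GeLU kernel in between, i.e. output = gelu(input @ weight + bias) @ weight_out. A rough PyTorch equivalent for reference (not part of the commit; exact numerics depend on the GeLU approximation used by launch_bias_gelu):

import torch
import torch.nn.functional as F

def fused_gemm_gelu_ref(x, weight, bias, weight_out):
    # x: [batch, seq, k], weight: [k, n], bias: [n], weight_out: [n, m]
    intermediate = F.gelu(torch.matmul(x, weight) + bias)
    return torch.matmul(intermediate, weight_out)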
void gptj_residual_add(at::Tensor& output,
                       at::Tensor& input,
                       at::Tensor& attention_output,
                       at::Tensor& output_b)
{
    int bsz = input.size(0) * input.size(1);
    int hidden_size = input.size(2);
    // cudaStreamWaitEvent(
    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
    if (input.scalar_type() == at::kFloat)
        launch_gptj_residual_add<float>((float*)input.data_ptr(),
                                        (float*)output.data_ptr(),
                                        (float*)attention_output.data_ptr(),
                                        (float*)output_b.data_ptr(),
                                        hidden_size,
                                        bsz,
                                        Context::Instance().GetCurrentStream());
    else
        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
                                         (__half*)output.data_ptr(),
                                         (__half*)attention_output.data_ptr(),
                                         (__half*)output_b.data_ptr(),
                                         hidden_size,
                                         bsz,
                                         Context::Instance().GetCurrentStream());
}

std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             at::Tensor& key_layer,
                                             unsigned rotary_dim,
                                             unsigned offset,
                                             unsigned num_heads)
{
    auto query_cont = mixed_query.contiguous();
    auto key_cont = key_layer.contiguous();

    unsigned bsz = mixed_query.size(0);
    unsigned head_size = mixed_query.size(2) / num_heads;
    unsigned seq_len = mixed_query.size(1);

    if (mixed_query.scalar_type() == at::kFloat)
        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
                                           (float*)key_cont.data_ptr(),
                                           head_size,
                                           seq_len,
                                           rotary_dim,
                                           offset,
                                           num_heads,
                                           bsz,
                                           Context::Instance().GetCurrentStream());
    else
        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                            (__half*)key_cont.data_ptr(),
                                            head_size,
                                            seq_len,
                                            rotary_dim,
                                            offset,
                                            num_heads,
                                            bsz,
                                            Context::Instance().GetCurrentStream());
    return {query_cont, key_cont};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");

@@ -627,4 +752,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("linear_layer_int8",
          &ds_linear_layer_int8<__half>,
          "DeepSpeed linear_layer with int8 (CUDA)");
    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("gptj_residual_add", &gptj_residual_add, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
}
@@ -40,7 +40,7 @@ inline int DS_GET_BLOCKS(const int N)

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);

@@ -50,12 +50,16 @@ public:
            throw std::runtime_error(message);
        }
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
    }

    static Context& Instance()

@@ -81,13 +85,19 @@ public:

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)

@@ -108,5 +118,11 @@ private:
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;

    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;

    cudaStream_t _stream;

    std::vector<std::array<int, 3>> _gemm_algos;
};
@@ -77,3 +77,22 @@ void launch_dequantize(T* output,
                       unsigned groups,
                       unsigned merge_count,
                       cudaStream_t stream);

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              int batch,
                              int head_size,
                              cudaStream_t stream);
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 cudaStream_t stream);
@@ -2,7 +2,7 @@ import copy
import torch
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy
from .replace_policy import replace_policies
from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
from ..runtime.weight_quantizer import WeightQuantization

@@ -200,11 +200,12 @@ def replace_transformer_layer(orig_layer_impl,
                _4hh_w = _4hh_w.half()

            if quantize or fp16:
                dense_b = dense_b.half()
                qkvb = qkvb if qkvb is None else qkvb.half()
                dense_b = dense_b if dense_b is None else dense_b.half()
                _h4h_b = _h4h_b.half()
                _4hh_b = _4hh_b.half()
                attn_nw = attn_nw.half()
                attn_nb = attn_nb.half()
                attn_nw = attn_nw if dense_b is None else attn_nw.half()
                attn_nb = attn_nb if dense_b is None else attn_nb.half()
                input_nw = input_nw.half()
                input_nb = input_nb.half()

@@ -216,7 +217,9 @@ def replace_transformer_layer(orig_layer_impl,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                    'layer_norm_eps') else
                (config.layer_norm_epsilon if hasattr(config,
                                                      'layer_norm_epsilon') else 1e-12),
                fp16=fp16,
                pre_layer_norm=preln,
                mp_size=mp_size,

@@ -227,7 +230,10 @@ def replace_transformer_layer(orig_layer_impl,
                                 if hasattr(config,
                                            'attention_layers') else False),
                window_size=(config.window_size if hasattr(config,
                                                           'window_size') else 1))
                                                           'window_size') else 1),
                rotary_dim=(config.rotary_dim if hasattr(config,
                                                         'rotary_dim') else -1),
                mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))

            if quantize and quantize_settings is not None:
                (quantization_scales,

@@ -278,18 +284,10 @@ def replace_transformer_layer(orig_layer_impl,
                attn_block = new_module.attention
                attn_block.attn_qkvw.data = mp_replace.qkv_copy(attn_block.attn_qkvw.data,
                                                                qkvw)

                if qkvb is not None:
                    if fp16:
                        qkvb = qkvb.half()
                    attn_block.attn_qkvb.data = mp_replace.qkv_copy(
                        attn_block.attn_qkvb.data,
                        qkvb)
                else:
                    attn_block.attn_qkvb = qkvb
                attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb.data, qkvb)

                attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w)
                attn_block.attn_ob.data = mp_replace.copy(attn_block.attn_ob.data, dense_b)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob.data, dense_b)

                mpl_block = new_module.mlp
                mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data, _h4h_w)

@@ -297,8 +295,10 @@ def replace_transformer_layer(orig_layer_impl,
                mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data, _4hh_w)
                mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data, _4hh_b)

                new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                new_module.mlp.attn_nw = attn_nw if dense_b is None else attn_nw.to(
                    torch.cuda.current_device())
                new_module.mlp.attn_nb = attn_nb if dense_b is None else attn_nb.to(
                    torch.cuda.current_device())
                new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
                new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
            else:
@@ -143,6 +143,50 @@ class HFGPTNEOLayerPolicy(DSPolicy):
            self.client_module.ln_1.bias.data


class HFGPTJLayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=True)
        self.client_module = client_module
        try:
            import transformers
            HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock
        except:
            HFGPTJLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.q_proj.weight.data.shape[1], \
                self.client_module.attn.num_attention_heads

    def attention(self):
        qw = self.client_module.attn.q_proj.weight.data
        kw = self.client_module.attn.k_proj.weight.data
        vw = self.client_module.attn.v_proj.weight.data

        qkvw = torch.cat((qw, kw, vw), dim=0)

        return self.linear_layer, \
               qkvw, \
               None, \
               self.client_module.attn.out_proj.weight.data, \
               None, \
               self.scale_attention

    def mlp(self):
        return self.linear_layer, \
               self.client_module.mlp.fc_in.weight.data, \
               self.client_module.mlp.fc_in.bias.data, \
               self.client_module.mlp.fc_out.weight.data, \
               self.client_module.mlp.fc_out.bias.data

    def layerNorm(self):
        return None, \
               None, \
               self.client_module.ln_1.weight.data, \
               self.client_module.ln_1.bias.data


class MegatronLayerPolicy(DSPolicy):
    _orig_layer_class = None


@@ -234,6 +278,7 @@ class HFGPT2LayerPolicy(DSPolicy):
replace_policies = [
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
]
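HFGPTJLayerPolicy above is what lets the module-injection pass recognize Hugging Face GPTJBlock layers and extract their fused QKV weight (GPT-J's attention carries no QKV or output-projection biases) together with the single pre-attention LayerNorm. A hypothetical smoke test of the policy, not part of the commit (the deepspeed.module_inject.replace_policy import path and the default GPTJConfig are assumptions):

import transformers
from deepspeed.module_inject.replace_policy import HFGPTJLayerPolicy  # assumed path

block = transformers.models.gptj.modeling_gptj.GPTJBlock(transformers.GPTJConfig())
policy = HFGPTJLayerPolicy(block)
hidden_size, num_heads = policy.get_hidden_heads()
_, qkvw, qkvb, dense_w, dense_b, _ = policy.attention()
assert qkvb is None and dense_b is None  # GPT-J attention has no biases here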
@@ -226,19 +226,21 @@ def top1gating(logits: Tensor,
        exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
    assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)
    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1
    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)
    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
@@ -66,7 +66,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
                 triangular_masking=True,
                 local_attention=False,
                 window_size=256,
                 return_tuple=True):
                 rotary_dim=-1,
                 return_tuple=True,
                 mlp_after_attn=True):
        super(DeepSpeedInferenceConfig,
              self).__init__(
                  hidden_size,

@@ -85,7 +87,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
        self.triangular_masking = triangular_masking
        self.local_attention = local_attention
        self.window_size = window_size
        self.rotary_dim = rotary_dim
        self.return_tuple = return_tuple
        self.mlp_after_attn = mlp_after_attn

    @classmethod
    def from_dict(cls, json_object):
@@ -180,6 +184,14 @@ class DeepSpeedSelfAttentionFunction(Function):
            unfused_mode = not config.specialized_mode or \
                mixed_query.shape[1] >= 32 or head_size > 128

            if config.rotary_dim > 0:
                mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                    mixed_query,
                    key_layer,
                    config.rotary_dim,
                    0 if layer_past is None else layer_past[0].shape[-2],
                    num_attention_heads_per_partition)

            if layer_past is not None:
                past_key, past_value = layer_past
                if unfused_mode:

@@ -189,21 +201,23 @@ class DeepSpeedSelfAttentionFunction(Function):
                value_layer = torch.cat((past_value.type_as(value_layer),
                                         value_layer),
                                        dim=-2)
            presents = (key_layer, value_layer)

            if unfused_mode:
                mixed_query = _transpose_for_scores(mixed_query, False, True)
                key_layer1 = _transpose_for_scores(
                key_layer = _transpose_for_scores(
                    key_layer,
                    True,
                    True) / (norm_factor if config.scale_attention else 1.0)
                value_layer1 = _transpose_for_scores(value_layer, False, True)
                value_layer = _transpose_for_scores(value_layer, False, True)

            if layer_past is None:
                attn_key_value = score_context_func(
                    mixed_query,
                    (key_layer1 if unfused_mode else key_layer),
                    key_layer,
                    torch.empty(1),
                    (input_mask),
                    (value_layer1 if unfused_mode else value_layer),
                    input_mask,
                    value_layer,
                    torch.empty(1),
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),

@@ -215,11 +229,11 @@ class DeepSpeedSelfAttentionFunction(Function):
            else:
                attn_key_value = score_context_func(
                    mixed_query,
                    (key_layer1 if unfused_mode else past_key.type_as(key_layer)),
                    (key_layer1 if unfused_mode else key_layer),
                    (input_mask),
                    (value_layer1 if unfused_mode else past_value.type_as(value_layer)),
                    (value_layer1 if unfused_mode else value_layer),
                    (key_layer if unfused_mode else past_key.type_as(key_layer)),
                    key_layer,
                    input_mask,
                    (value_layer if unfused_mode else past_value.type_as(value_layer)),
                    value_layer,
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    (not unfused_mode),

@@ -235,7 +249,7 @@ class DeepSpeedSelfAttentionFunction(Function):
            # Transpose Context
            context_layer = _transpose_for_context(context_layer)

            return context_layer, key_layer, value_layer
            return context_layer, presents[0], presents[1]  # atten_output, key_layer, value_layer

        def selfAttention_fp():
            vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
@@ -255,10 +269,10 @@ class DeepSpeedSelfAttentionFunction(Function):
                norm_b,
                config.epsilon,
                (attn_qkvb is not None))
            context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
            context_layer, key_layer, value_layer = compute_attention(qkv_out[0], input_mask)
            output = vector_matmul_func(context_layer, attn_ow)

            return output, key_layer, value_layer, context_layer
            return output, key_layer, value_layer, context_layer, qkv_out[-1]  # attn_out, present_key, present_value, context_output, inp_norm

        def selfAttention_int8():
            if not config.pre_layer_norm:

@@ -290,12 +304,12 @@ class DeepSpeedSelfAttentionFunction(Function):
        if config.q_int8:
            output, key_layer, value_layer, context_layer = selfAttention_int8()
        else:
            output, key_layer, value_layer, context_layer = selfAttention_fp()
            output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()

        if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
            torch.distributed.all_reduce(output, group=mp_group)

        return (output, key_layer, value_layer, context_layer)
        return (output, key_layer, value_layer, context_layer, inp_norm)

    @staticmethod
    def backward(ctx, grad_output, grad_output1, grad_output2, grad_output3):

@@ -396,9 +410,12 @@ class DeepSpeedMLPFunction(Function):
                output_w,
                q_scales,
                q_groups,
                merge_count):
                merge_count,
                mlp_gemm_func,
                fused_gemm_gelu,
                vector_matmul_func,
                bias_residual_func):
        if config.q_int8:

            (intermediate,
             residual_add) = inference_cuda_module.mlp_gemm_int8(
                 input,
@@ -418,29 +435,26 @@ class DeepSpeedMLPFunction(Function):
                q_groups,
                (merge_count))
        else:
            mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \
                                    inference_cuda_module.mlp_gemm_fp32
            vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
                                    inference_cuda_module.vector_matmul_fp32
            (intermediate,
             residual_add) = mlp_gemm_func(input,
                                           residual,
                                           bias,
                                           inter_w,
                                           inter_b,
                                           attn_nw,
                                           attn_nb,
                                           config.epsilon,
                                           config.pre_layer_norm)
            output = vector_matmul_func(intermediate, output_w)
            if attn_nw is None:
                output = fused_gemm_gelu(input, inter_w, inter_b, output_w)
            else:
                (intermediate,
                 residual_add) = mlp_gemm_func(input,
                                               residual,
                                               bias,
                                               inter_w,
                                               inter_b,
                                               attn_nw,
                                               attn_nb,
                                               config.epsilon,
                                               config.pre_layer_norm)
                output = vector_matmul_func(intermediate, output_w)

        if mp_group is not None and torch.distributed.get_world_size(group=mp_group) > 1:
            torch.distributed.all_reduce(output, group=mp_group)

        bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
                                    inference_cuda_module.bias_residual_fp32

        output = bias_residual_func(output, residual_add, output_b)
        if attn_nw is not None:
            output = bias_residual_func(output, residual_add, output_b)

        return output

@@ -479,6 +493,15 @@ class DeepSpeedMLP(nn.Module):
        self.merge_count = int(math.log2(merge_count))

        self.mp_group = mp_group
        self.mlp_gemm_func = inference_cuda_module.mlp_gemm_fp16 if config.fp16 else \
                                    inference_cuda_module.mlp_gemm_fp32
        self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
                                    inference_cuda_module.vector_matmul_fp32
        self.fused_gemm_gelu = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
                                    inference_cuda_module.fused_gemm_gelu_fp32

        self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
                                    inference_cuda_module.bias_residual_fp32

    def forward(self, input, residual, bias):
        return DeepSpeedMLPFunction.apply(input,
@@ -494,7 +517,11 @@ class DeepSpeedMLP(nn.Module):
                                          self.output_w,
                                          self.q_scales,
                                          self.q_groups,
                                          self.merge_count)
                                          self.merge_count,
                                          self.mlp_gemm_func,
                                          self.fused_gemm_gelu,
                                          self.vector_matmul_func,
                                          self.bias_residual_func)


class DeepSpeedTransformerInference(nn.Module):

@@ -528,21 +555,6 @@ class DeepSpeedTransformerInference(nn.Module):
        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1
        self.attention = DeepSpeedSelfAttention(self.config,
                                                mp_group,
                                                quantize_scales,
                                                quantize_groups,
                                                merge_count,
                                                qkv_merging)
        self.mlp = DeepSpeedMLP(self.config,
                                mp_group,
                                quantize_scales,
                                quantize_groups,
                                merge_count,
                                mlp_extra_grouping)

        self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
        self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))

        global inference_cuda_module
        global specialized_mode

@@ -560,6 +572,22 @@ class DeepSpeedTransformerInference(nn.Module):
            self.config.specialized_mode = specialized_mode
        print("DeepSpeed Transformer Inference config is ", self.config.__dict__)

        self.attention = DeepSpeedSelfAttention(self.config,
                                                mp_group,
                                                quantize_scales,
                                                quantize_groups,
                                                merge_count,
                                                qkv_merging)
        self.mlp = DeepSpeedMLP(self.config,
                                mp_group,
                                quantize_scales,
                                quantize_groups,
                                merge_count,
                                mlp_extra_grouping)

        self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
        self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))

    def forward(self,
                input,
                input_mask=None,

@@ -596,13 +624,15 @@ class DeepSpeedTransformerInference(nn.Module):
                                          self.norm_b)

        if get_present:
            attention_output, p_key, p_value, _ = attention_output
            presents = (p_key, p_value)
            presents = (attention_output[1], attention_output[2])
        elif output_attentions:
            attention_output, _, _, context_output = attention_output
        else:
            attention_output, _, _, _ = attention_output
        output = self.mlp(attention_output, input, self.attention.attn_ob)
            context_output = attention_output[3]

        output = self.mlp(
            attention_output[0]
            if self.config.mlp_after_attn else attention_output[-1],
            input,
            self.attention.attn_ob)

        if not self.config.pre_layer_norm:
            ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \

@@ -612,8 +642,13 @@ class DeepSpeedTransformerInference(nn.Module):
                                    self.norm_b,
                                    self.config.epsilon)

        if input_type != output.dtype:
            output = output.to(input_type)
        if not self.config.mlp_after_attn:
            inference_cuda_module.gptj_residual_add(output,
                                                    input,
                                                    attention_output[0],
                                                    self.mlp.output_b)

        output = output.to(input_type)

        if get_present:
            output = (output, presents)
@@ -19,6 +19,7 @@ class InferenceBuilder(CUDAOpBuilder):
            'csrc/transformer/inference/csrc/normalize.cu',
            'csrc/transformer/inference/csrc/softmax.cu',
            'csrc/transformer/inference/csrc/dequantize.cu',
            'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
        ]

    def include_paths(self):