Mirror of https://github.com/microsoft/DeepSpeed.git
Add local attention for GPT-Neo model architecture (#1114)
* Fix links for the inference tutorial
* Fix automatic injection; add local attention for GPT-Neo
* Fix inference for generation of large sequences (>1K & <32K)
* Fix formatting

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 8def3cb3a2
Commit: aca7fc549a
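The "local" attention layers in GPT-Neo restrict each query to the most recent window_size keys, on top of the usual causal (triangular) mask; the kernel and injection changes below add exactly that windowing. As a point of reference, here is an editor's sketch of the masking rule in NumPy (the function name and shapes are illustrative, not part of the commit):

# Editor's sketch, not part of the commit: reference semantics of causal + local
# attention masking for a single head. `scores` has shape [num_queries, seq_len].
import numpy as np

def causal_local_softmax(scores, window_size, local_attention=True):
    num_queries, seq_len = scores.shape
    # During incremental generation num_queries < seq_len; align each query with
    # its absolute position (the role real_seq_id/window_stride play in the kernel).
    offset = seq_len - num_queries
    q = np.arange(num_queries)[:, None] + offset   # absolute query positions
    k = np.arange(seq_len)[None, :]                # key positions
    allowed = k <= q                               # causal (triangular) mask
    if local_attention:
        allowed &= k > (q - window_size)           # keep only the last window_size keys
    masked = np.where(allowed, scores, -np.inf)
    masked -= masked.max(axis=-1, keepdims=True)
    probs = np.exp(masked)
    return probs / probs.sum(axis=-1, keepdims=True)

# Example: with window_size=2, the query at position 3 attends to keys {2, 3} only.
print(causal_local_softmax(np.zeros((4, 4)), window_size=2))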
@@ -13,7 +13,9 @@ template <typename T>
 at::Tensor ds_softmax(at::Tensor& attn_scores,
                       at::Tensor& attn_mask,
                       bool triangular,
-                      bool recompute)
+                      bool recompute,
+                      bool local_attention,
+                      int window_size)
 {
     auto attn_scores_c = attn_scores.contiguous();
     int bsz = attn_scores_c.size(0);

@@ -25,6 +27,8 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
                           (T*)attn_mask.data_ptr(),
                           triangular,
                           recompute,
+                          local_attention,
+                          window_size,
                           bsz,
                           heads,
                           seq_len,

@@ -47,7 +51,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
                        int& heads,
                        float& norm_factor,
                        bool triangular,
-                       bool recompute)
+                       bool recompute,
+                       bool local_attention,
+                       int window_size)
 {
     auto options = at::TensorOptions()
                        .dtype(query_cont.options().dtype())

@@ -75,7 +81,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
                                 seq_len * soft_len,
                                 bsz * heads,
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
-    attn_score = ds_softmax<T>(attn_score, attn_mask, triangular, recompute);
+    attn_score =
+        ds_softmax<T>(attn_score, attn_mask, triangular, recompute, local_attention, window_size);
     alpha = 1.0;
     cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                 k,

@@ -105,7 +112,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
                                            int heads,
                                            float norm_factor,
                                            bool merging,
-                                           bool triangular)
+                                           bool triangular,
+                                           bool local_attention,
+                                           int window_size)
 {
     auto query_cont = query.contiguous();
     auto prev_key_cont = prev_key.contiguous();

@@ -138,7 +147,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
                           heads,
                           norm_factor,
                           (triangular && (new_size == 0)),
-                          (new_size == 0));
+                          (new_size == 0),
+                          local_attention,
+                          window_size);

     return {output, prev_key, prev_value};
 }
@@ -6,10 +6,8 @@
 #include <cstdlib>
 #include <ctime>

-#define Attn_Threads 128
-#define Reduce_Threads 32
-#define attn_warps 4
-#define MAX_ATTN_REG 4 // MAX Head Size 256
+#define ATTN_THREADS 1024
+#define MAX_REG_SIZE 8

 #define minus_infinity (-1 * std::numeric_limits<float>::infinity())

@@ -26,31 +24,40 @@ void CheckCudaErrorAux(const char* file, unsigned line)

 namespace cg = cooperative_groups;

-template <int tbSize, int tbSeq>
 __global__ void attn_softmax_v2(__half* vals,
                                 __half* mask,
                                 bool triangular,
                                 bool recompute,
+                                bool local_attention,
+                                int window_size,
                                 int total_count,
                                 int heads,
                                 int sequence_length,
                                 int num_seq,
-                                float scale)
+                                float scale,
+                                int iterations,
+                                int reduceWidth)
 {
 #if __CUDA_ARCH__ >= 700

     cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
+    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

-    float2 low_data[tbSeq];
-    float2 high_data[tbSeq];
+    float2 low_data[MAX_REG_SIZE];
+    float2 high_data[MAX_REG_SIZE];

     __half2 h_scale = __float2half2_rn(scale);

     int wid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
+    int warp_num = blockDim.x >> 5;

-    int iter_offset = blockIdx.x * (blockDim.x >> 5) + wid;
+    int reduce_blocks = reduceWidth >> 5;
+    int seq_lane = threadIdx.x % reduceWidth;
+
+    __shared__ float partialSum[MAX_WARP_NUM];
+
+    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);

     if (iter_offset < total_count) {
         vals += (iter_offset * sequence_length);

@@ -58,20 +65,33 @@ __global__ void attn_softmax_v2(__half* vals,
         int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
         int seq_id = iter_offset % num_seq;
         int seq_id4 = seq_id >> 2;

+        int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
+        int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
+                                 ? (real_seq_id >> 2) - (window_size >> 2)
+                                 : 0;
+        int window_stride =
+            (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
+
         float max_val = minus_infinity;
-        for (int i = 0; i < tbSeq; i++) {
-            int data_id = i * (WARP_SIZE << 2) + (lane << 2);
-            if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) {
+
+        for (int i = 0; i < iterations; i++) {
+            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
+            if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
+                data_id < sequence_length) {
                 if ((sequence_length - data_id) >= 4) {
-                    low_data[i].x = __half2float(vals[data_id]);
-                    low_data[i].y = (!triangular || ((data_id + 1) <= seq_id))
+                    low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
+                                                            : minus_infinity;
+                    low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
+                                     (data_id + 1) > window_stride)
                                         ? __half2float(vals[data_id + 1])
                                         : minus_infinity;
-                    high_data[i].x = (!triangular || ((data_id + 2) <= seq_id))
+                    high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
+                                      (data_id + 2) > window_stride)
                                          ? __half2float(vals[data_id + 2])
                                          : minus_infinity;
-                    high_data[i].y = (!triangular || ((data_id + 3) <= seq_id))
+                    high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
+                                      (data_id + 3) > window_stride)
                                          ? __half2float(vals[data_id + 3])
                                          : minus_infinity;
                     if (mask && !triangular && recompute) {

@@ -81,12 +101,15 @@ __global__ void attn_softmax_v2(__half* vals,
                         high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
                     }
                 } else {
-                    low_data[i].x = __half2float(vals[data_id]);
-                    low_data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
+                    low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
+                                                            : minus_infinity;
+                    low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
+                                      (data_id + 1) > window_stride) &&
                                      (data_id + 1) < sequence_length)
                                         ? __half2float(vals[data_id + 1])
                                         : minus_infinity;
-                    high_data[i].x = (((!triangular || (data_id + 2) <= seq_id)) &&
+                    high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
+                                       (data_id + 2) > window_stride) &&
                                       (data_id + 2) < sequence_length)
                                          ? __half2float(vals[data_id + 2])
                                          : minus_infinity;

@@ -112,13 +135,29 @@ __global__ void attn_softmax_v2(__half* vals,
             }
         }

-        for (int i = 1; i < tbSize; i *= 2) {
+        for (int i = 1; i < WARP_SIZE; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

+        if (reduceWidth > WARP_SIZE) {
+            if (lane == 0) partialSum[wid] = max_val;
+            b.sync();
+
+            if (lane < warp_num) max_val = partialSum[lane];
+
+            b.sync();
+
+            for (int i = 1; i < reduce_blocks; i *= 2) {
+                auto temp = g.shfl_xor(max_val, i);
+                max_val = (temp > max_val ? temp : max_val);
+            }
+
+            max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
+        }
+
         float sum = 0;
-        for (int i = 0; i < tbSeq; i++) {
+        for (int i = 0; i < iterations; i++) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);

@@ -127,12 +166,24 @@ __global__ void attn_softmax_v2(__half* vals,
             sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
         }

-        for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i);
+        for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

+        if (reduceWidth > WARP_SIZE) {
+            if (lane == 0) partialSum[wid] = sum;
+            b.sync();
+
+            if (lane < warp_num) sum = partialSum[lane];
+
+            b.sync();
+
+            for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
+
+            sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
+        }
         sum += 1e-6;

-        for (int i = 0; i < tbSeq; i++) {
-            int data_id = i * (WARP_SIZE << 2) + (lane << 2);
+        for (int i = 0; i < iterations; i++) {
+            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);

             if (data_id < sequence_length) {
                 if ((sequence_length - data_id) >= 4) {

@@ -151,46 +202,69 @@ __global__ void attn_softmax_v2(__half* vals,
 #endif
 }

-template <int tbSize, int tbSeq>
 __global__ void attn_softmax_v2(float* vals,
                                 float* attn_mask,
                                 bool triangular,
                                 bool recompute,
+                                bool local_attention,
+                                int window_size,
                                 int total_count,
                                 int heads,
                                 int sequence_length,
                                 int num_seq,
-                                float scale)
+                                float scale,
+                                int iterations,
+                                int reduceWidth)
 {
     cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
+    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

-    float4 data[tbSeq];
+    float4 data[MAX_REG_SIZE];

     int wid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
     int warp_num = blockDim.x >> 5;

-    int iter_offset = blockIdx.x * warp_num + wid;
+    int reduce_blocks = reduceWidth >> 5;
+    int seq_lane = threadIdx.x % reduceWidth;
+
+    __shared__ float partialSum[MAX_WARP_NUM];
+
+    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
     if (iter_offset < total_count) {
         vals += (iter_offset * sequence_length);

         int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
         int seq_id = iter_offset % num_seq;
         int seq_id4 = seq_id >> 2;

+        int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
+        int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
+                                 ? (real_seq_id >> 2) - (window_size >> 2)
+                                 : 0;
+        int window_stride =
+            (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
+
         float max_val = minus_infinity;
-        for (int i = 0; i < tbSeq; i++) {
-            int data_id = i * (WARP_SIZE << 2) + (lane << 2);
-            if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) {
+
+        for (int i = 0; i < iterations; i++) {
+            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
+            if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
+                data_id < sequence_length) {
                 if ((sequence_length - data_id) >= 4) {
-                    data[i].x = (vals[data_id]);
-                    data[i].y = (!triangular || ((data_id + 1) <= seq_id)) ? (vals[data_id + 1])
-                                                                           : minus_infinity;
-                    data[i].z = (!triangular || ((data_id + 2) <= seq_id)) ? (vals[data_id + 2])
-                                                                           : minus_infinity;
-                    data[i].w = (!triangular || ((data_id + 3) <= seq_id)) ? (vals[data_id + 3])
-                                                                           : minus_infinity;
+                    data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
+                    data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
+                                 (data_id + 1) > window_stride)
+                                    ? vals[data_id + 1]
+                                    : minus_infinity;
+                    data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
+                                 (data_id + 2) > window_stride)
+                                    ? vals[data_id + 2]
+                                    : minus_infinity;
+                    data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
+                                 (data_id + 3) > window_stride)
+                                    ? vals[data_id + 3]
+                                    : minus_infinity;
                     if (attn_mask && !triangular && recompute) {
                         data[i].x += attn_mask[data_id + mask_offset];
                         data[i].y += attn_mask[data_id + mask_offset + 1];

@@ -198,13 +272,13 @@ __global__ void attn_softmax_v2(float* vals,
                         data[i].w += attn_mask[data_id + mask_offset + 3];
                     }
                 } else {
-                    data[i].x = (vals[data_id]);
+                    data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
                     data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
-                                 (data_id + 1) < sequence_length)
+                                 (data_id + 1) > window_stride && (data_id + 1) < sequence_length)
                                     ? (vals[data_id + 1])
                                     : minus_infinity;
                     data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
-                                 (data_id + 2) < sequence_length)
+                                 (data_id + 2) > window_stride && (data_id + 2) < sequence_length)
                                     ? (vals[data_id + 2])
                                     : minus_infinity;
                     data[i].w = minus_infinity;

@@ -228,13 +302,29 @@ __global__ void attn_softmax_v2(float* vals,
             }
         }

-        for (int i = 1; i < tbSize; i *= 2) {
+        for (int i = 1; i < WARP_SIZE; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

+        if (reduceWidth > WARP_SIZE) {
+            if (lane == 0) partialSum[wid] = max_val;
+            b.sync();
+
+            if (lane < warp_num) max_val = partialSum[lane];
+
+            b.sync();
+
+            for (int i = 1; i < reduce_blocks; i *= 2) {
+                auto temp = g.shfl_xor(max_val, i);
+                max_val = (temp > max_val ? temp : max_val);
+            }
+
+            max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
+        }
+
         float sum = 0;
-        for (int i = 0; i < tbSeq; i++) {
+        for (int i = 0; i < iterations; i++) {
            data[i].x = __expf(data[i].x - max_val);
            data[i].y = __expf(data[i].y - max_val);
            data[i].z = __expf(data[i].z - max_val);

@@ -243,12 +333,24 @@ __global__ void attn_softmax_v2(float* vals,
             sum += (data[i].x + data[i].y + data[i].z + data[i].w);
         }

-        for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i);
+        for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

+        if (reduceWidth > WARP_SIZE) {
+            if (lane == 0) partialSum[wid] = sum;
+            b.sync();
+
+            if (lane < warp_num) sum = partialSum[lane];
+
+            b.sync();
+
+            for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
+
+            sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
+        }
         sum += 1e-6;

-        for (int i = 0; i < tbSeq; i++) {
-            int data_id = i * (WARP_SIZE << 2) + (lane << 2);
+        for (int i = 0; i < iterations; i++) {
+            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);

             if (data_id < sequence_length) {
                 if ((sequence_length - data_id) >= 4) {

@@ -271,6 +373,8 @@ void launch_attn_softmax_v2(T* vals,
                             T* mask,
                             bool triangular,
                             bool recompute,
+                            bool local_attention,
+                            int window_size,
                             int batch_size,
                             int heads,
                             int num_seq,

@@ -279,36 +383,36 @@ void launch_attn_softmax_v2(T* vals,
                             cudaStream_t stream)
 {
     int total_count = batch_size * heads * num_seq;
-    dim3 grid_dim((total_count - 1) / 32 + 1);
-    dim3 block_dim(1024);
-    if (sequence_length <= 128)
-        attn_softmax_v2<32, 1><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
-    else if (sequence_length <= 256)
-        attn_softmax_v2<32, 2><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
-    else if (sequence_length <= 512)
-        attn_softmax_v2<32, 4><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
-    else if (sequence_length <= 1024)
-        attn_softmax_v2<32, 8><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
-    else if (sequence_length <= 2048)
-        attn_softmax_v2<32, 16><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
-    else if (sequence_length <= 4096)
-        attn_softmax_v2<32, 32><<<grid_dim, block_dim, 0, stream>>>(
-            vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
+    dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
+    dim3 block_dim(ATTN_THREADS);
+
+    const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
+    const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
+
+    if (sequence_length <= 32768)
+        attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(vals,
+                                                            mask,
+                                                            triangular,
+                                                            recompute,
+                                                            local_attention,
+                                                            window_size,
+                                                            total_count,
+                                                            heads,
+                                                            sequence_length,
+                                                            num_seq,
+                                                            scale,
+                                                            iterations,
+                                                            reduce_width);
     else
-        throw std::runtime_error(
-            "Unsupport Seq_Length! Check the restriction of the max_threads and "
-            "max_thread_iterations!");
+        throw std::runtime_error("Unsupport Seq_Length!");
 }

 template void launch_attn_softmax_v2(float* vals,
                                      float* mask,
                                      bool triangular,
                                      bool recompute,
+                                     bool local_attention,
+                                     int window_size,
                                      int batch_size,
                                      int heads,
                                      int num_seq,

@@ -319,6 +423,8 @@ template void launch_attn_softmax_v2(__half* vals,
                                      __half* mask,
                                      bool triangular,
                                      bool recompute,
+                                     bool local_attention,
+                                     int window_size,
                                      int batch_size,
                                      int heads,
                                      int num_seq,
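The launch changes above replace the per-length template instantiations (<32, 1> through <32, 32>) with a single kernel whose per-row reduction width and per-thread iteration count are computed at launch time; the register arrays are sized by MAX_REG_SIZE, which is what caps the supported sequence length at 32K. An editor's sketch of that arithmetic, not part of the commit (constants mirror the macros above; the helper name is illustrative):

# Editor's sketch: launch geometry implied by launch_attn_softmax_v2 above.
WARP_SIZE, ATTN_THREADS, MAX_REG_SIZE = 32, 1024, 8

def launch_geometry(sequence_length, total_count):
    warps_per_row = (sequence_length - 1) // ATTN_THREADS + 1     # ceil(seq_len / 1024)
    reduce_width = warps_per_row * WARP_SIZE                      # threads cooperating on one row
    iterations = (sequence_length - 1) // (reduce_width * 4) + 1  # float4 chunks held per thread
    grid_dim = (total_count - 1) // (WARP_SIZE // warps_per_row) + 1
    block_dim = ATTN_THREADS
    # Register arrays are sized MAX_REG_SIZE, so iterations must never exceed it;
    # this holds for sequence_length <= 32768, hence the new length check.
    assert iterations <= MAX_REG_SIZE
    return grid_dim, block_dim, reduce_width, iterations

# A 32K-token sequence uses reduce_width=1024 and iterations=8 (= MAX_REG_SIZE).
print(launch_geometry(32768, total_count=1 * 16 * 32768))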
@@ -18,6 +18,8 @@ void launch_attn_softmax_v2(T* vals,
                             T* mask,
                             bool triangular,
                             bool recompute,
+                            bool local_attention,
+                            int window_size,
                             int batch_size,
                             int heads,
                             int num_seq,
@@ -78,8 +78,12 @@ class InferenceEngine(Module):
     def _get_model_config_generate(self):
         if hasattr(self.module, 'config'):
             self.config = self.module.config
+        else:
+            self.config = None
         if hasattr(self.module, 'generate'):
             self.generate = self.module.generate
+        else:
+            self.generate = None

     def _create_model_parallel_group(self):
         # Call the init process

@@ -134,6 +138,7 @@ class InferenceEngine(Module):
                                          policy=injection_policy,
                                          mp_size=self.mp_world_size,
                                          mp_group=self.mp_group,
+                                         config=self.config,
                                          fp16=(self.dtype == torch.half),
                                          training=False,
                                          quantize=(self.dtype == torch.int8),
@@ -90,7 +90,7 @@ def replace_transformer_layer(orig_layer_impl,
                               model,
                               policy=None,
                               micro_batch_size=-1,
-                              bert_config=None,
+                              config=None,
                               seed=-1,
                               hidden_size=-1,
                               num_attention_heads=-1,

@@ -110,7 +110,7 @@ def replace_transformer_layer(orig_layer_impl,
         model (torch.nn.Module): user's nn.module representing their model
         policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
         micro_batch_size (int): micro batch size per gpu used during training/eval
-        bert_config (dict): model config containing hidden size, attention heads, etc.
+        config (dict): model config containing hidden size, attention heads, etc.
         seed (int): random seed value
         max_seq_length (int): max sequence length for training
         hidden_size (int): hidden dimension

@@ -155,7 +155,6 @@ def replace_transformer_layer(orig_layer_impl,
                 _4hh_w = _4hh_w.half()

             if quantize or fp16:
-                qkvb = qkvb.half()
                 dense_b = dense_b.half()
                 _h4h_b = _h4h_b.half()
                 _4hh_b = _4hh_b.half()

@@ -175,7 +174,12 @@ def replace_transformer_layer(orig_layer_impl,
                 mp_size=mp_size,
                 q_int8=quantize,
                 encoder_decoder=(True if policy_cls is HFBertLayerPolicy else False),
-                triangular_masking=(policy_cls is not HFBertLayerPolicy))
+                triangular_masking=(policy_cls is not HFBertLayerPolicy),
+                local_attention=((config.attention_layers[layer_id] == "local")
+                                 if hasattr(config,
+                                            'attention_layers') else False),
+                window_size=(config.window_size if hasattr(config,
+                                                           'window_size') else 1))

             if quantize and quantize_settings is not None:
                 (quantization_scales,

@@ -208,8 +212,8 @@ def replace_transformer_layer(orig_layer_impl,
             new_module.config.scale_attention = scale_attention

             # we want the weights in [input, output] shape
-            # linear layer is created is created with [input, output] shape
-            # we transpose it here to reduce inference cost!
+            # linear layer is created with [input, output] shape
+            # transpose it here to reduce inference cost!
             def transpose(data):
                 data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
                 data = data.reshape(data.shape[-1], data.shape[-2])

@@ -228,6 +232,7 @@ def replace_transformer_layer(orig_layer_impl,
                     qkvw)

                 if qkvb is not None:
+                    qkvb = qkvb.half()
                     attn_block.attn_qkvb.data = mp_replace.qkv_copy(
                         attn_block.attn_qkvb.data,
                         qkvb)

@@ -250,12 +255,12 @@ def replace_transformer_layer(orig_layer_impl,
         else:
             transformer_config = deepspeed.DeepSpeedTransformerConfig(
                 batch_size=micro_batch_size,
-                hidden_size=bert_config.hidden_size,
-                heads=bert_config.num_attention_heads,
-                attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
-                hidden_dropout_ratio=bert_config.hidden_dropout_prob,
-                num_hidden_layers=bert_config.num_hidden_layers,
-                initializer_range=bert_config.initializer_range,
+                hidden_size=config.hidden_size,
+                heads=config.num_attention_heads,
+                attn_dropout_ratio=config.attention_probs_dropout_prob,
+                hidden_dropout_ratio=config.hidden_dropout_prob,
+                num_hidden_layers=config.num_hidden_layers,
+                initializer_range=config.initializer_range,
                 seed=seed,
                 fp16=fp16,
                 pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),

@@ -302,20 +307,20 @@ def replace_transformer_layer(orig_layer_impl,
                           _replace_policy=policy)


-def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
+def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
     """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
     Arguments:
         orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
             e.g., transformers.modeling_bert.BertLayer.
         model (torch.nn.Module): user's nn.module representing their model
-        bert_config (dict): model config containing hidden size, attention heads, etc.
+        config (dict): model config containing hidden size, attention heads, etc.

     Returns:
         Updated nn.module with original bert-style transformer layers
     """
     def replace_fn(child, _replace_policy, layer_id):
         #from turing.nvidia_modelingpreln import BertLayer
-        orig_module = orig_layer_impl(bert_config)
+        orig_module = orig_layer_impl(config)

         # copy relevant state from child -> original module
         qkvw = child.attn_qkvw.data

@@ -389,8 +394,11 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
         for plcy in replace_policies:
             # instantiate a throw-away policy in order to populate the _orig_layer_class
             _ = plcy(None)
-            assert plcy._orig_layer_class != None
-            policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
+            if plcy._orig_layer_class is not None:
+                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
+    assert len(policy.items()) > 0,\
+        "No default policy found! Please specifiy your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\
+        "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

     replaced_module, _ = _replace_module(model, policy)
     return replaced_module
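For reference, the injection path changed above is normally driven through deepspeed.init_inference with automatic replacement; the sketch below shows how that is typically wired up for GPT-Neo. It is an editor's illustration based on the DeepSpeed inference tutorial of this era, not part of the commit, and the exact arguments should be checked against that tutorial:

# Editor's sketch, not part of the commit: exercising the GPT-Neo injection path.
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "EleutherAI/gpt-neo-2.7B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# replace_method='auto' routes through replace_transformer_layer(); for layers whose
# config.attention_layers entry is "local", local_attention=True and config.window_size
# are forwarded into DeepSpeedInferenceConfig, enabling the windowed softmax kernel.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_method='auto')

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(torch.cuda.current_device())
outputs = engine.module.generate(**inputs, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))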
@@ -74,7 +74,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
                  stochastic_mode=False,
                  encoder_decoder=False,
                  scale_attention=True,
-                 triangular_masking=True):
+                 triangular_masking=True,
+                 local_attention=False,
+                 window_size=256):
         super(DeepSpeedInferenceConfig,
               self).__init__(
                   hidden_size,

@@ -92,6 +94,8 @@ class DeepSpeedInferenceConfig(TransformerConfig):
         self.scale_attention = scale_attention
         self.specialized_mode = None
         self.triangular_masking = triangular_masking
+        self.local_attention = local_attention
+        self.window_size = window_size

     @classmethod
     def from_dict(cls, json_object):

@@ -218,7 +222,9 @@ class DeepSpeedSelfAttentionFunction(Function):
                     num_attention_heads_per_partition,
                     (1 / norm_factor if config.scale_attention else 1.0),
                     (not unfused_mode),
-                    config.triangular_masking)
+                    config.triangular_masking,
+                    config.local_attention,
+                    config.window_size)
             else:
                 attn_key_value = score_context_func(
                     mixed_query,

@@ -230,8 +236,10 @@ class DeepSpeedSelfAttentionFunction(Function):
                     num_attention_heads_per_partition,
                     (1 / norm_factor if config.scale_attention else 1.0),
                     (not unfused_mode),
-                    config.triangular_masking)
+                    config.triangular_masking,
+                    config.local_attention,
+                    config.window_size)
             #import pdb;pdb.set_trace()
             if unfused_mode:
                 context_layer, _, _ = attn_key_value
             else:

@@ -522,6 +530,8 @@ class DeepSpeedTransformerInference(nn.Module):
         of a Transformer layer. We use this feature for quantization to reduce the convergence impact
         for specific downstream tasks.
     """
+    layer_id = 0
+
     def __init__(self,
                  config,
                  mp_group=None,

@@ -533,6 +543,9 @@ class DeepSpeedTransformerInference(nn.Module):
         super(DeepSpeedTransformerInference, self).__init__()

         self.config = config
+        self.config.layer_id = DeepSpeedTransformerInference.layer_id
+        DeepSpeedTransformerInference.layer_id += 1
+
         self.attention = DeepSpeedSelfAttention(config,
                                                 mp_group,
                                                 quantize_scales,