Add local attention for GPT-Neo model architecture (#1114)

* Fix links for the inference tutorial

* Fix automatic injection; add local attention for GPT-Neo

* Fix inference for generation of long sequences (>1K and <32K tokens)

* Fix formatting

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Reza Yazdani 2021-06-08 11:44:59 -07:00 committed by GitHub
Parent 8def3cb3a2
Commit aca7fc549a
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 242 additions and 97 deletions
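
For orientation, here is a minimal usage sketch of the inference path this commit touches: GPT-Neo served through deepspeed.init_inference with automatic kernel injection. This is not part of the commit; the model name, generation arguments, and the replace_method='auto' flag follow the inference tutorial of this DeepSpeed release and are assumptions, not code from this change.

# Hypothetical sketch: GPT-Neo inference with DeepSpeed kernel injection.
# The injected layers pick up attention_layers/window_size from the HF config
# (see the replace_module.py changes below).
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-neo-2.7B"          # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_method='auto')   # automatic injection path

inputs = tokenizer("DeepSpeed with local attention", return_tensors="pt").to("cuda")
outputs = engine.module.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))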

View file

@@ -13,7 +13,9 @@ template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
bool triangular,
bool recompute)
bool recompute,
bool local_attention,
int window_size)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);
@@ -25,6 +27,8 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
(T*)attn_mask.data_ptr(),
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
@@ -47,7 +51,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
int& heads,
float& norm_factor,
bool triangular,
bool recompute)
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
@@ -75,7 +81,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(attn_score, attn_mask, triangular, recompute);
attn_score =
ds_softmax<T>(attn_score, attn_mask, triangular, recompute, local_attention, window_size);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
@@ -105,7 +112,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
int heads,
float norm_factor,
bool merging,
bool triangular)
bool triangular,
bool local_attention,
int window_size)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
@@ -138,7 +147,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0));
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
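
The new local_attention/window_size arguments threaded through ds_softmax, attention_unfused, and ds_softmax_context describe a sliding-window causal mask: a query position may only attend to the last window_size key positions, on top of the usual triangular mask. Below is a plain-Python sketch of the per-row semantics, reconstructed from the kernel changes further down as an illustration rather than taken from this commit (during incremental decoding the kernel additionally offsets the query index by the cached sequence length).

import math

def masked_softmax_row(scores, seq_id, triangular, local_attention, window_size):
    # scores: raw attention logits for query position seq_id over all key positions.
    # window_stride mirrors the kernel: keys at or before it are masked out.
    window_stride = seq_id - window_size if (local_attention and seq_id >= window_size) else -1

    masked = []
    for k, s in enumerate(scores):
        causal_ok = (not triangular) or (k <= seq_id)
        window_ok = k > window_stride
        masked.append(s if (causal_ok and window_ok) else float("-inf"))

    max_val = max(masked)
    exps = [math.exp(v - max_val) for v in masked]
    denom = sum(exps) + 1e-6                      # same epsilon as the kernel
    return [e / denom for e in exps]

# Query 5 with a window of 3 only attends to keys 3, 4 and 5.
print(masked_softmax_row([0.1] * 8, seq_id=5, triangular=True,
                         local_attention=True, window_size=3))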

View file

@@ -6,10 +6,8 @@
#include <cstdlib>
#include <ctime>
#define Attn_Threads 128
#define Reduce_Threads 32
#define attn_warps 4
#define MAX_ATTN_REG 4 // MAX Head Size 256
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity (-1 * std::numeric_limits<float>::infinity())
@@ -26,31 +24,40 @@ void CheckCudaErrorAux(const char* file, unsigned line)
namespace cg = cooperative_groups;
template <int tbSize, int tbSeq>
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale)
float scale,
int iterations,
int reduceWidth)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[tbSeq];
float2 high_data[tbSeq];
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int iter_offset = blockIdx.x * (blockDim.x >> 5) + wid;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
@@ -58,20 +65,33 @@ __global__ void attn_softmax_v2(__half* vals,
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < tbSeq; i++) {
int data_id = i * (WARP_SIZE << 2) + (lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) {
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = __half2float(vals[data_id]);
low_data[i].y = (!triangular || ((data_id + 1) <= seq_id))
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (!triangular || ((data_id + 2) <= seq_id))
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = (!triangular || ((data_id + 3) <= seq_id))
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && !triangular && recompute) {
@@ -81,12 +101,15 @@ __global__ void attn_softmax_v2(__half* vals,
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = __half2float(vals[data_id]);
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id)) &&
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
@@ -112,13 +135,29 @@ __global__ void attn_softmax_v2(__half* vals,
}
}
for (int i = 1; i < tbSize; i *= 2) {
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < tbSeq; i++) {
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
@@ -127,12 +166,24 @@ __global__ void attn_softmax_v2(__half* vals,
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i);
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < tbSeq; i++) {
int data_id = i * (WARP_SIZE << 2) + (lane << 2);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
@@ -151,46 +202,69 @@ __global__ void attn_softmax_v2(__half* vals,
#endif
}
template <int tbSize, int tbSeq>
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale)
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[tbSeq];
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int iter_offset = blockIdx.x * warp_num + wid;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < tbSeq; i++) {
int data_id = i * (WARP_SIZE << 2) + (lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && data_id < sequence_length) {
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (vals[data_id]);
data[i].y = (!triangular || ((data_id + 1) <= seq_id)) ? (vals[data_id + 1])
: minus_infinity;
data[i].z = (!triangular || ((data_id + 2) <= seq_id)) ? (vals[data_id + 2])
: minus_infinity;
data[i].w = (!triangular || ((data_id + 3) <= seq_id)) ? (vals[data_id + 3])
: minus_infinity;
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && !triangular && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
@@ -198,13 +272,13 @@ __global__ void attn_softmax_v2(float* vals,
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = (vals[data_id]);
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) < sequence_length)
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) < sequence_length)
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
@@ -228,13 +302,29 @@ __global__ void attn_softmax_v2(float* vals,
}
}
for (int i = 1; i < tbSize; i *= 2) {
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < tbSeq; i++) {
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
@@ -243,12 +333,24 @@ __global__ void attn_softmax_v2(float* vals,
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) sum += g.shfl_xor(sum, i);
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < tbSeq; i++) {
int data_id = i * (WARP_SIZE << 2) + (lane << 2);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
@@ -271,6 +373,8 @@ void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
@@ -279,36 +383,36 @@ void launch_attn_softmax_v2(T* vals,
cudaStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / 32 + 1);
dim3 block_dim(1024);
if (sequence_length <= 128)
attn_softmax_v2<32, 1><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
else if (sequence_length <= 256)
attn_softmax_v2<32, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
else if (sequence_length <= 512)
attn_softmax_v2<32, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
else if (sequence_length <= 1024)
attn_softmax_v2<32, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
else if (sequence_length <= 2048)
attn_softmax_v2<32, 16><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
else if (sequence_length <= 4096)
attn_softmax_v2<32, 32><<<grid_dim, block_dim, 0, stream>>>(
vals, mask, triangular, recompute, total_count, heads, sequence_length, num_seq, scale);
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
heads,
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
@@ -319,6 +423,8 @@ template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
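
The rewritten launcher drops the per-length template dispatch in favour of one kernel whose reduction width and per-thread iteration count are computed on the host; when reduce_width exceeds a warp, the kernel adds a second, shared-memory (partialSum) stage to the max and sum reductions. Below is a small Python sketch of the launch arithmetic, assuming WARP_SIZE is 32 (ATTN_THREADS and MAX_REG_SIZE are the constants defined at the top of this file); it only mirrors the integer math in launch_attn_softmax_v2 and is not DeepSpeed code.

WARP_SIZE = 32          # assumed; defined elsewhere in the CUDA headers
ATTN_THREADS = 1024
MAX_REG_SIZE = 8        # max 4-wide chunks each thread keeps in registers

def softmax_launch_geometry(batch_size, heads, num_seq, sequence_length):
    total_count = batch_size * heads * num_seq

    # How many warps cooperate on one softmax row, and the resulting
    # reduction width / per-lane iteration count.
    warps_per_row = (sequence_length - 1) // ATTN_THREADS + 1
    reduce_width = warps_per_row * WARP_SIZE
    iterations = (sequence_length - 1) // (reduce_width * 4) + 1

    rows_per_block = WARP_SIZE // warps_per_row
    grid_dim = (total_count - 1) // rows_per_block + 1

    # At most one full block (32 warps) cooperates on a row, which is the
    # same limit as the sequence_length <= 32768 guard above.
    assert reduce_width <= ATTN_THREADS and iterations <= MAX_REG_SIZE, \
        "Unsupport Seq_Length!"
    return grid_dim, ATTN_THREADS, reduce_width, iterations

# 1K tokens: one warp per row; 32K tokens: a full 1024-thread block per row.
print(softmax_launch_geometry(1, 16, 1, 1024))    # (1, 1024, 32, 8)
print(softmax_launch_geometry(1, 16, 1, 32768))   # (16, 1024, 1024, 8)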

View file

@@ -18,6 +18,8 @@ void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,

View file

@@ -78,8 +78,12 @@ class InferenceEngine(Module):
def _get_model_config_generate(self):
if hasattr(self.module, 'config'):
self.config = self.module.config
else:
self.config = None
if hasattr(self.module, 'generate'):
self.generate = self.module.generate
else:
self.generate = None
def _create_model_parallel_group(self):
# Call the init process
@@ -134,6 +138,7 @@ class InferenceEngine(Module):
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
quantize=(self.dtype == torch.int8),

View file

@@ -90,7 +90,7 @@ def replace_transformer_layer(orig_layer_impl,
model,
policy=None,
micro_batch_size=-1,
bert_config=None,
config=None,
seed=-1,
hidden_size=-1,
num_attention_heads=-1,
@@ -110,7 +110,7 @@ def replace_transformer_layer(orig_layer_impl,
model (torch.nn.Module): user's nn.module representing their model
policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
micro_batch_size (int): micro batch size per gpu used during training/eval
bert_config (dict): model config containing hidden size, attention heads, etc.
config (dict): model config containing hidden size, attention heads, etc.
seed (int): random seed value
max_seq_length (int): max sequence length for training
hidden_size (int): hidden dimension
@@ -155,7 +155,6 @@ def replace_transformer_layer(orig_layer_impl,
_4hh_w = _4hh_w.half()
if quantize or fp16:
qkvb = qkvb.half()
dense_b = dense_b.half()
_h4h_b = _h4h_b.half()
_4hh_b = _4hh_b.half()
@@ -175,7 +174,12 @@ def replace_transformer_layer(orig_layer_impl,
mp_size=mp_size,
q_int8=quantize,
encoder_decoder=(True if policy_cls is HFBertLayerPolicy else False),
triangular_masking=(policy_cls is not HFBertLayerPolicy))
triangular_masking=(policy_cls is not HFBertLayerPolicy),
local_attention=((config.attention_layers[layer_id] == "local")
if hasattr(config,
'attention_layers') else False),
window_size=(config.window_size if hasattr(config,
'window_size') else 1))
if quantize and quantize_settings is not None:
(quantization_scales,
@@ -208,8 +212,8 @@ def replace_transformer_layer(orig_layer_impl,
new_module.config.scale_attention = scale_attention
# we want the weights in [input, output] shape
# linear layer is created is created with [input, output] shape
# we transpose it here to reduce inference cost!
# linear layer is created with [input, output] shape
# transpose it here to reduce inference cost!
def transpose(data):
data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
data = data.reshape(data.shape[-1], data.shape[-2])
@@ -228,6 +232,7 @@ def replace_transformer_layer(orig_layer_impl,
qkvw)
if qkvb is not None:
qkvb = qkvb.half()
attn_block.attn_qkvb.data = mp_replace.qkv_copy(
attn_block.attn_qkvb.data,
qkvb)
@@ -250,12 +255,12 @@ def replace_transformer_layer(orig_layer_impl,
else:
transformer_config = deepspeed.DeepSpeedTransformerConfig(
batch_size=micro_batch_size,
hidden_size=bert_config.hidden_size,
heads=bert_config.num_attention_heads,
attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
hidden_dropout_ratio=bert_config.hidden_dropout_prob,
num_hidden_layers=bert_config.num_hidden_layers,
initializer_range=bert_config.initializer_range,
hidden_size=config.hidden_size,
heads=config.num_attention_heads,
attn_dropout_ratio=config.attention_probs_dropout_prob,
hidden_dropout_ratio=config.hidden_dropout_prob,
num_hidden_layers=config.num_hidden_layers,
initializer_range=config.initializer_range,
seed=seed,
fp16=fp16,
pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
@@ -302,20 +307,20 @@ def replace_transformer_layer(orig_layer_impl,
_replace_policy=policy)
def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
""" Revert DeepSpeed's transformer layer back to original bert-style transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
bert_config (dict): model config containing hidden size, attention heads, etc.
config (dict): model config containing hidden size, attention heads, etc.
Returns:
Updated nn.module with original bert-style transformer layers
"""
def replace_fn(child, _replace_policy, layer_id):
#from turing.nvidia_modelingpreln import BertLayer
orig_module = orig_layer_impl(bert_config)
orig_module = orig_layer_impl(config)
# copy relevant state from child -> original module
qkvw = child.attn_qkvw.data
@@ -389,8 +394,11 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
for plcy in replace_policies:
# instantiate a throw-away policy in order to populate the _orig_layer_class
_ = plcy(None)
assert plcy._orig_layer_class != None
policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
if plcy._orig_layer_class is not None:
policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
assert len(policy.items()) > 0,\
"No default policy found! Please specifiy your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\
"You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"
replaced_module, _ = _replace_module(model, policy)
return replaced_module

View file

@@ -74,7 +74,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
stochastic_mode=False,
encoder_decoder=False,
scale_attention=True,
triangular_masking=True):
triangular_masking=True,
local_attention=False,
window_size=256):
super(DeepSpeedInferenceConfig,
self).__init__(
hidden_size,
@@ -92,6 +94,8 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.scale_attention = scale_attention
self.specialized_mode = None
self.triangular_masking = triangular_masking
self.local_attention = local_attention
self.window_size = window_size
@classmethod
def from_dict(cls, json_object):
@@ -218,7 +222,9 @@ class DeepSpeedSelfAttentionFunction(Function):
num_attention_heads_per_partition,
(1 / norm_factor if config.scale_attention else 1.0),
(not unfused_mode),
config.triangular_masking)
config.triangular_masking,
config.local_attention,
config.window_size)
else:
attn_key_value = score_context_func(
mixed_query,
@@ -230,8 +236,10 @@ class DeepSpeedSelfAttentionFunction(Function):
num_attention_heads_per_partition,
(1 / norm_factor if config.scale_attention else 1.0),
(not unfused_mode),
config.triangular_masking)
config.triangular_masking,
config.local_attention,
config.window_size)
#import pdb;pdb.set_trace()
if unfused_mode:
context_layer, _, _ = attn_key_value
else:
@@ -522,6 +530,8 @@ class DeepSpeedTransformerInference(nn.Module):
of a Transformer layer. We use this feature for quantization to reduce the convergence impact
for specific downstream tasks.
"""
layer_id = 0
def __init__(self,
config,
mp_group=None,
@@ -533,6 +543,9 @@ class DeepSpeedTransformerInference(nn.Module):
super(DeepSpeedTransformerInference, self).__init__()
self.config = config
self.config.layer_id = DeepSpeedTransformerInference.layer_id
DeepSpeedTransformerInference.layer_id += 1
self.attention = DeepSpeedSelfAttention(config,
mp_group,
quantize_scales,
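
Two small but load-bearing additions close out this file: DeepSpeedInferenceConfig now carries local_attention and window_size (default 256), and DeepSpeedTransformerInference keeps a class-level layer_id counter so each injected layer knows its own index, which is what the per-layer "global"/"local" lookup keys on. A toy illustration of the counter behaviour follows, using a stand-in class rather than the real module.

class _InferenceLayerDemo:
    # Mirrors the class-level counter added to DeepSpeedTransformerInference.
    layer_id = 0

    def __init__(self, config):
        self.config = config
        self.config.layer_id = _InferenceLayerDemo.layer_id
        _InferenceLayerDemo.layer_id += 1

class _Cfg:
    pass

layers = [_InferenceLayerDemo(_Cfg()) for _ in range(3)]
print([layer.config.layer_id for layer in layers])   # -> [0, 1, 2]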