Make static KV cache work.:
This commit is contained in:
Родитель
6d9636f07c
Коммит
9773e68439
|
@ -436,7 +436,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
|
|||
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
|
||||
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
|
||||
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
|
||||
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
|
||||
const int total_sequence_length = parameters.is_gqa_ && parameters.past_present_share_buffer_ ? parameters.seqlen_present_kv_cache_ : (past_sequence_length + parameters.kv_sequence_length_);
|
||||
|
||||
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.sequence_length_, total_sequence_length});
|
||||
|
|
Загрузка…
Ссылка в новой задаче