This commit is contained in:
Satya Jandhyala 2024-12-09 17:28:03 -08:00
Родитель 6d9636f07c
Коммит 9773e68439
1 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -436,7 +436,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
const int total_sequence_length = parameters.is_gqa_ && parameters.past_present_share_buffer_ ? parameters.seqlen_present_kv_cache_ : (past_sequence_length + parameters.kv_sequence_length_);
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, total_sequence_length});