diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index ea5ab3d091461..a8459111b31b9 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -68,22 +68,15 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
   return context.RunProgram(program);
 };
 
-void InitVarStub(const Tensor* seqlens_k, const Tensor* total_seqlen_tensor, bool init_past_sequence_length, std::ostringstream& ss) {
-  if (seqlens_k != nullptr && total_seqlen_tensor != nullptr) {
-    ss << "let total_sequence_length_input = u32(total_seqlen_tensor[0]);\n";
-    ss << "let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);\n";
-    ss << "let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;\n";
-    ss << "let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;\n";
+void InitVarStub(std::ostringstream& ss, const Tensor* seqlens_k) {
+  if (seqlens_k != nullptr) {
     ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n";
     ss << "var past_sequence_length: u32 = 0;\n";
-    ss << "if (is_first_prompt == false) {\n";
+    ss << "if (uniforms.is_first_prompt == 0) {\n";
     ss << "  past_sequence_length = total_sequence_length - sequence_length;\n";
     ss << "}\n";
   } else {
-    if (init_past_sequence_length) {
-      ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
-    }
-    ss << "let present_sequence_length = total_sequence_length;\n";
+    ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
   }
 }
 
@@ -99,9 +92,6 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (seqlen_k_ != nullptr) {
     shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
   }
-  if (total_seqlen_tensor_ != nullptr) {
-    shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-  }
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   if (has_present_key_) {
     shader.AddOutput("present_key", ShaderUsage::UseUniform);
@@ -118,7 +108,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.N;\n";
   std::ostringstream oss;
-  InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+  InitVarStub(oss, seqlen_k_);
   shader.MainFunctionBody() << oss.str();
   if (n_reps_ > 1) {
     shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
@@ -163,7 +153,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   }
 
   if (has_present_key_) {
-    shader.MainFunctionBody() << "    if (n + local_id.y < present_sequence_length) {\n"
+    shader.MainFunctionBody() << "    if (n + local_id.y < uniforms.present_sequence_length) {\n"
                               << "      present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
                               << "    }\n";
   }
@@ -194,7 +184,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
                              const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
                              WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length,
-                             const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+                             const Tensor* seqlen_k) {
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
 
@@ -205,7 +195,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
 
   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                                components, parameters.n_reps, seqlen_k, total_seqlen_tensor};
+                                components, parameters.n_reps, seqlen_k};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                      {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
   if (feed_past_key) {
@@ -214,9 +204,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   if (has_attention_bias) {
     program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
   }
-  if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
-    program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
-                       {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+  if (seqlen_k != nullptr) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
   if (has_present_key) {
@@ -228,7 +217,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                             (parameters.sequence_length_ + tile_size - 1) / tile_size,
                             parameters.batch_size_ * parameters.num_heads_)
       .SetWorkgroupSize(tile_size, tile_size)
-      .CacheHint(std::to_string(tile_size))
+      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(vectorized_head_size)},
                             {static_cast<uint32_t>(total_sequence_length)},
@@ -237,7 +226,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                             {static_cast<float>(alpha)},
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
-                            {static_cast<uint32_t>(parameters.n_reps)}})
+                            {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
+                            {static_cast<uint32_t>(parameters.n_reps)},
+                            {static_cast<uint32_t>(parameters.is_first_prompt_)}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
 
   return context.RunProgram(program);
@@ -247,9 +238,6 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (seqlen_k_) {
     shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
   }
-  if (total_seqlen_tensor_) {
-    shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-  }
   shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
                                     << "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
@@ -259,7 +247,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.sequence_length;\n"
                             << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
   std::ostringstream oss;
-  InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+  InitVarStub(oss, seqlen_k_);
   shader.MainFunctionBody() << oss.str()
                             << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
                             << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
@@ -304,7 +292,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
 }
 
 Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
-                             const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+                             const Tensor* seqlen_k, bool is_first_prompt) {
   const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
   int work_group_size = 64;
   const int total_sequence_length_comp = total_sequence_length / components;
@@ -313,12 +301,12 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
   }
   const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
 
-  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k, total_seqlen_tensor};
-  if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
-    program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
-                       {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
+  if (seqlen_k != nullptr) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
+      .CacheHint(std::to_string(work_group_size), is_first_prompt)
       .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
@@ -326,7 +314,8 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(sequence_length)},
                             {static_cast<uint32_t>(total_sequence_length_comp)},
-                            {static_cast<uint32_t>(elementsPerThread)}});
+                            {static_cast<uint32_t>(elementsPerThread)},
+                            {static_cast<uint32_t>(is_first_prompt)}});
 
   return context.RunProgram(program);
 }
@@ -340,9 +329,6 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (seqlen_k_) {
     shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
   }
-  if (total_seqlen_tensor_) {
-    shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-  }
   shader.AddOutput("output", ShaderUsage::UseUniform);
   if (has_present_value_) {
     shader.AddOutput("present_value", ShaderUsage::UseUniform);
@@ -358,7 +344,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.K;\n";
   std::ostringstream oss;
-  InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+  InitVarStub(oss, seqlen_k_);
   shader.MainFunctionBody() << oss.str();
   if (n_reps_ > 1) {
     shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
@@ -404,7 +390,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
   }
 
   if (has_present_value_) {
-    shader.MainFunctionBody() << "    if (w + local_id.y < present_sequence_length) {\n"
+    shader.MainFunctionBody() << "    if (w + local_id.y < uniforms.present_sequence_length) {\n"
                               << "      present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
                               << "    }\n";
   }
@@ -436,21 +422,19 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
                                WebgpuAttentionParameters& parameters,
                                int past_sequence_length,
                                int total_sequence_length,
-                               const Tensor* seqlen_k,
-                               const Tensor* total_seqlen_tensor) {
+                               const Tensor* seqlen_k) {
   const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0;
   const bool has_present_value = output_count > 1 && past_value != nullptr;
   const int tile_size = 12;
 
-  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, total_seqlen_tensor};
+  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k};
   program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
                      {V, ProgramTensorMetadataDependency::TypeAndRank}});
   if (feed_past_value) {
     program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
   }
-  if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
-    program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
-                       {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+  if (seqlen_k != nullptr) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
   }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
   if (has_present_value) {
@@ -460,6 +444,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
                                (parameters.sequence_length_ + tile_size - 1) / tile_size,
                                parameters.batch_size_ * parameters.num_heads_)
+      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
       .SetWorkgroupSize(tile_size, tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(total_sequence_length)},
@@ -469,16 +454,17 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
                             {static_cast<uint32_t>(parameters.v_hidden_size_)},
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
-                            {static_cast<uint32_t>(parameters.n_reps)}})
+                            {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
+                            {static_cast<uint32_t>(parameters.n_reps)},
+                            {static_cast<uint32_t>(parameters.is_first_prompt_)}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
-  ;
 
   return context.RunProgram(program);
 }
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
   const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
@@ -488,13 +474,13 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
   const TensorShape probs_shape(probs_dims);
   Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
   ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
-                                            parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor));
+                                            parameters, past_sequence_length, total_sequence_length, seqlen_k));
 
   ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
-                                            parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, total_seqlen_tensor));
+                                            parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));
 
   ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
-                                              parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor));
+                                              parameters, past_sequence_length, total_sequence_length, seqlen_k));
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index ea0ad7e03fc54..8c6e27b9f9227 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
 class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
  public:
   AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) {
+                        bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -48,7 +48,9 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
                                           {"alpha", ProgramUniformVariableDataType::Float32},
                                           {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"n_reps", ProgramUniformVariableDataType::Uint32});
+                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"n_reps", ProgramUniformVariableDataType::Uint32},
+                                          {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
 
   WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
 
@@ -60,13 +62,12 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
   int components_;
   int n_reps_;
   const Tensor* seqlen_k_;
-  const Tensor* total_seqlen_tensor_;
 };
 
 class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
  public:
-  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr)
-      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) {
+  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
+      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -76,19 +77,19 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
                                           {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
-                                          {"elements_per_thread", ProgramUniformVariableDataType::Uint32});
+                                          {"elements_per_thread", ProgramUniformVariableDataType::Uint32},
+                                          {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
 
  private:
   int work_group_size_;
   int components_;
   const Tensor* seqlen_k_;
-  const Tensor* total_seqlen_tensor_;
 };
 
 class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
  public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) {
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -101,7 +102,9 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
                                           {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
                                           {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"n_reps", ProgramUniformVariableDataType::Uint32});
+                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"n_reps", ProgramUniformVariableDataType::Uint32},
+                                          {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
 
   WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
 
@@ -111,7 +114,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
   int tile_size_;
   int n_reps_;
   const Tensor* seqlen_k_;
-  const Tensor* total_seqlen_tensor_;
 };
 
 }  // namespace webgpu
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index 230207ed26c1a..c010ffd49b86b 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -43,6 +43,7 @@ struct WebgpuAttentionParameters {
         batch_size_(parameters.batch_size),
         sequence_length_(parameters.sequence_length),
         kv_sequence_length_(parameters.sequence_length),
+        past_sequence_length_(parameters.seqlen_past_kv_cache),
         total_sequence_length_(parameters.total_sequence_length),
         hidden_size_(parameters.hidden_size),
         head_size_(parameters.head_size),
@@ -56,6 +57,8 @@ struct WebgpuAttentionParameters {
         kv_num_heads_(parameters.kv_num_heads),
         num_splits_(parameters.num_splits),
         rotary_dim_(parameters.rotary_dim),
+        is_subsequent_prompt_(parameters.is_subsequent_prompt),
+        is_first_prompt_(parameters.is_first_prompt),
         rotary_interleaved_(parameters.rotary_interleaved),
         use_smooth_softmax_(parameters.use_smooth_softmax),
         softcap_(parameters.softcap),
@@ -118,7 +121,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr);
+                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
 
 }  // namespace webgpu
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index 5c1f032d976aa..1dbeeeda20164 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -60,7 +60,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
   output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
-  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size_);
+  output_shape[2] = static_cast<int64_t>(parameters.hidden_size_);
   Tensor* output = context.Output(0, output_shape);
   const int present_kv_seqlen = parameters.seqlen_present_kv_cache_;
   std::vector<int64_t> present_kv_shape({static_cast<int64_t>(parameters.batch_size_), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(parameters.head_size_)});
@@ -77,7 +77,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
                                 parameters.kv_sequence_length_, parameters.head_size_});
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
     return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key,
-                          present_value, parameters, context, seqlen_k, total_seqlen_tensor);
+                          present_value, parameters, context, seqlen_k);
   }
   TensorShape k_new_shape(k_new_dims);
   Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
@@ -91,7 +91,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
                                         parameters.v_head_size_, value, nullptr, 2 * parameters.hidden_size_, &V));
   return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key,
-                        present_value, parameters, context, seqlen_k, total_seqlen_tensor);
+                        present_value, parameters, context, seqlen_k);
 }
 
 }  // namespace webgpu
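Standalone illustration (not part of the diff; names and the `bool` parameter are placeholders for this sketch only): a minimal mirror of the refactored `InitVarStub` helper, with the `const Tensor* seqlens_k` argument replaced by a plain `bool` so it compiles on its own. Running it prints the two WGSL prologues the helper can emit after this change: the `seqlens_k` path, which now reads the `is_first_prompt` uniform instead of deriving the flag in the shader, and the fallback path, which takes `past_sequence_length` directly from a uniform.

```cpp
// init_var_stub_sketch.cc -- illustrative only; the real helper lives in
// onnxruntime/contrib_ops/webgpu/bert/attention.cc and takes a const Tensor*.
#include <iostream>
#include <sstream>

void InitVarStubSketch(std::ostringstream& ss, bool has_seqlens_k) {
  if (has_seqlens_k) {
    // GQA path: per-batch total length comes from seqlens_k; past length is
    // derived in-shader, guarded by the new is_first_prompt uniform.
    ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n";
    ss << "var past_sequence_length: u32 = 0;\n";
    ss << "if (uniforms.is_first_prompt == 0) {\n";
    ss << "  past_sequence_length = total_sequence_length - sequence_length;\n";
    ss << "}\n";
  } else {
    // MHA path: past length is supplied directly as a uniform.
    ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
  }
}

int main() {
  for (bool has_seqlens_k : {true, false}) {
    std::ostringstream ss;
    InitVarStubSketch(ss, has_seqlens_k);
    std::cout << "--- has_seqlens_k = " << std::boolalpha << has_seqlens_k << " ---\n"
              << ss.str();
  }
  return 0;
}
```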