diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 0bc834e1af93d..ea5ab3d091461 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -125,7 +125,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
                               << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
                               << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
-                              << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";;
+                              << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";
   if (feed_past_key_ && has_present_key_) {
     shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n";
   }
@@ -201,7 +201,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
   const bool has_present_key = output_count > 1 && past_key;
   const bool has_attention_bias = attention_bias != nullptr;
-  const int tile_size = 12;
+  constexpr int tile_size = 12;
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
 
   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
@@ -257,13 +257,13 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
                             << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
                             << "let sequence_length = uniforms.sequence_length;\n"
-                            << "var total_sequence_length = uniforms.total_sequence_length;\n";
+                            << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
   std::ostringstream oss;
   InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
   shader.MainFunctionBody() << oss.str()
                             << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
-                            << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length + local_offset;\n"
-                            << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "total_sequence_length") << ";\n"
+                            << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
+                            << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
                             << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
                             << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                             << "  thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
@@ -295,7 +295,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "  }\n"
                             << "}\n";
   if (seqlen_k_) {
-    shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {\n"
+    shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < total_sequence_length; total_seq_id++) {\n"
                               << "   x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
                               << "}\n";
   }
@@ -319,7 +319,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
                        {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .SetDispatchGroupSize(batch_size * num_heads * sequence_length)
+      .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
                             {static_cast<uint32_t>(num_heads)},
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index 17f7b9e9c5fcd..ea0ad7e03fc54 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -75,7 +75,7 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
                             {"num_heads", ProgramUniformVariableDataType::Uint32},
                             {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                             {"sequence_length", ProgramUniformVariableDataType::Uint32},
-                            {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
+                            {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
                             {"elements_per_thread", ProgramUniformVariableDataType::Uint32});
 
  private: