diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 20e45f768017a..891aa3a425b70 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -68,13 +68,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
   return context.RunProgram(program);
 };
 
-void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
+void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
   if (seqlen_k != nullptr) {
     ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
-    ss << "var past_sequence_length: u32 = 0;\n";
-    ss << "if (uniforms.is_first_prompt != 0) {\n";
-    ss << "  past_sequence_length = total_sequence_length - sequence_length;\n";
-    ss << "}\n";
+    ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
   } else {
     ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
   }
@@ -108,7 +105,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.N;\n";
   std::ostringstream oss;
-  InitVarStub(oss, seqlen_k_);
+  InitVarStub(oss, seqlen_k_, is_first_prompt_);
   shader.MainFunctionBody() << oss.str();
   if (n_reps_ > 1) {
     shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
@@ -122,7 +119,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
       shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
     }
     if (has_present_key_) {
-      shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n";
+      shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
     }
   } else {
     shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
@@ -132,7 +129,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
       shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
     }
     if (has_present_key_) {
-      shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n";
+      shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
     }
   }
 
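For reference, the refactor resolves the first-prompt branch at shader-generation time instead of emitting a runtime `if` on `uniforms.is_first_prompt`. A standalone sketch of the codegen (simplified: a bool stands in for the `seqlen_k` tensor pointer, so this is illustrative rather than the actual function):

```cpp
#include <iostream>
#include <sstream>

// Sketch of the refactored stub: `has_seqlen_k` replaces the real
// `seqlen_k != nullptr` check on the Tensor pointer.
void InitVarStubSketch(std::ostringstream& ss, bool has_seqlen_k, bool is_first_prompt) {
  if (has_seqlen_k) {
    ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
    // The branch is folded at generation time, so the emitted WGSL
    // contains a plain initializer and no runtime conditional.
    ss << "var past_sequence_length: u32 = "
       << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
  } else {
    ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
  }
}

int main() {
  std::ostringstream prompt_pass, decode_pass;
  InitVarStubSketch(prompt_pass, /*has_seqlen_k=*/true, /*is_first_prompt=*/true);
  InitVarStubSketch(decode_pass, /*has_seqlen_k=*/true, /*is_first_prompt=*/false);
  std::cout << "first prompt:\n" << prompt_pass.str()
            << "later passes:\n" << decode_pass.str();
  return 0;
}
```

The cost of the fold is one extra shader variant per kernel, which is why `is_first_prompt` also joins the cache hints in the hunks below.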
@@ -203,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
 
   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                                components, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+                                components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                      {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
   if (feed_past_key) {
@@ -225,7 +222,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                               (parameters.sequence_length_ + tile_size - 1) / tile_size,
                               parameters.batch_size_ * parameters.num_heads_)
       .SetWorkgroupSize(tile_size, tile_size)
-      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr)
+      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(vectorized_head_size)},
                             {static_cast<uint32_t>(total_sequence_length)},
@@ -235,8 +232,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                             {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
-                            {static_cast<uint32_t>(parameters.n_reps)},
-                            {static_cast<uint32_t>(parameters.is_first_prompt_)}})
+                            {static_cast<uint32_t>(parameters.n_reps)}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
 
   return context.RunProgram(program);
@@ -255,7 +251,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.sequence_length;\n"
                             << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
   std::ostringstream oss;
-  InitVarStub(oss, seqlen_k_);
+  InitVarStub(oss, seqlen_k_, is_first_prompt_);
   shader.MainFunctionBody() << oss.str()
                             << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
                             << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
@@ -309,12 +305,12 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
   }
 
   const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
-  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
+  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
   if (seqlen_k != nullptr) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .CacheHint(std::to_string(work_group_size))
+      .CacheHint(work_group_size, is_first_prompt)
       .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
@@ -322,8 +318,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(sequence_length)},
                             {static_cast<uint32_t>(total_sequence_length_comp)},
-                            {static_cast<uint32_t>(elementsPerThread)},
-                            {static_cast<uint32_t>(is_first_prompt)}});
+                            {static_cast<uint32_t>(elementsPerThread)}});
 
   return context.RunProgram(program);
 }
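Since `is_first_prompt` (and `components`, for the probs kernel) now alters the generated WGSL rather than a uniform value, both are added to `CacheHint` above; without that, a pipeline compiled for one variant could be reused for the other. A toy model of the invariant, with a plain string key standing in for the real program cache (an assumption for illustration, not ORT's actual cache implementation):

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Toy stand-in for the pipeline cache key: it must cover every input that
// influences shader *source*, while values that only flow through uniforms
// (sequence lengths, n_reps, ...) stay out of the key.
std::string CacheKey(int tile_size, bool share_buffer, bool feed_past_key,
                     bool has_present_key, bool has_bias, bool has_seqlen_k,
                     int components, bool is_first_prompt) {
  std::ostringstream key;
  key << tile_size << '|' << share_buffer << '|' << feed_past_key << '|'
      << has_present_key << '|' << has_bias << '|' << has_seqlen_k << '|'
      << components << '|' << is_first_prompt;
  return key.str();
}

int main() {
  // The prompt and decode variants emit different WGSL, so their keys differ.
  assert(CacheKey(12, true, true, true, false, true, 4, true) !=
         CacheKey(12, true, true, true, false, true, 4, false));
  return 0;
}
```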
@@ -352,7 +347,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.K;\n";
   std::ostringstream oss;
-  InitVarStub(oss, seqlen_k_);
+  InitVarStub(oss, seqlen_k_, is_first_prompt_);
   shader.MainFunctionBody() << oss.str();
   if (n_reps_ > 1) {
     shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
@@ -366,7 +361,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
     }
 
     if (has_present_value_) {
-      shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n";
+      shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n";
     }
   } else {
     shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n";
@@ -377,7 +372,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
     }
 
     if (has_present_value_) {
-      shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.K + n;\n";
+      shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n";
     }
   }
 
@@ -407,7 +402,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
     } else {
       shader.MainFunctionBody() << "  if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
     }
-    shader.MainFunctionBody() << "      present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
+    shader.MainFunctionBody() << "      present_value[presentValueOffset + (w + local_id.y) * uniforms.present_sequence_length] = tileK[idx];\n"
                               << "    }\n";
   }
 
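The `uniforms.N`/`uniforms.K` → `uniforms.present_sequence_length` replacements above and in the probs kernel are the substantive fix: when past and present share a preallocated KV buffer, per-head offsets and row strides must use the buffer's capacity, not the live total sequence length. A small model of the present-key offset under that layout assumption (each head owns `present_sequence_length * head_size` elements):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative layout check, not the real kernel: using the live total
// sequence length N as the per-head stride (the old code) makes head blocks
// overlap or shift whenever N != present_sequence_length.
uint32_t PresentKeyOffset(uint32_t head_idx, uint32_t seq_stride, uint32_t head_size) {
  return head_idx * seq_stride * head_size;
}

int main() {
  const uint32_t present_sequence_length = 2048;  // buffer capacity, rows per head
  const uint32_t n = 300;                         // current total sequence length
  const uint32_t head_size = 64;

  // Fixed stride: head 1 starts exactly one full head-block after head 0.
  assert(PresentKeyOffset(1, present_sequence_length, head_size) == 2048u * 64u);
  // Old stride: head 1 would land inside head 0's block (300*64 < 2048*64),
  // corrupting the shared cache on every pass where n < capacity.
  assert(PresentKeyOffset(1, n, head_size) <
         PresentKeyOffset(1, present_sequence_length, head_size));
  return 0;
}
```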
@@ -443,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   const bool has_present_value = output_count > 1 && past_value != nullptr;
   const int tile_size = 12;
 
-  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
   program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
                      {V, ProgramTensorMetadataDependency::TypeAndRank}});
   if (feed_past_value) {
@@ -460,7 +455,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
                                (parameters.sequence_length_ + tile_size - 1) / tile_size,
                                parameters.batch_size_ * parameters.num_heads_)
-      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr)
+      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
       .SetWorkgroupSize(tile_size, tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(total_sequence_length)},
@@ -471,8 +466,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
                             {static_cast<uint32_t>(past_sequence_length)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                             {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
-                            {static_cast<uint32_t>(parameters.n_reps)},
-                            {static_cast<uint32_t>(parameters.is_first_prompt_)}})
+                            {static_cast<uint32_t>(parameters.n_reps)}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
 
   return context.RunProgram(program);
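That completes the .cc changes. One value that deliberately stays a uniform is `elements_per_thread`: it only controls how softmax threads partition a packed row, not the generated source, so it is computed per dispatch with the ceiling division in `ComputeInPlaceSoftmax` above. A quick check of that partitioning (values are illustrative only):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Each of the work_group_size threads covers elements_per_thread
  // consecutive packed elements, starting (per the shader code above) at
  // local_idx * elements_per_thread.
  const uint32_t work_group_size = 64;
  const uint32_t total_sequence_length_comp = 100;
  const uint32_t elements_per_thread =
      (total_sequence_length_comp + work_group_size - 1) / work_group_size;
  assert(elements_per_thread == 2);  // ceil(100 / 64)
  // The last thread starts past the row end (63 * 2 = 126 >= 100); the full
  // shader loop (truncated in this diff) is bounded by total_sequence_length,
  // so such threads do no work.
  assert((work_group_size - 1) * elements_per_thread >= total_sequence_length_comp);
  return 0;
}
```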
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index ee3aa1957cdd9..03279fffbc3ef 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
 class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
  public:
   AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -49,8 +49,7 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
                                            {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                           {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
+                                           {"n_reps", ProgramUniformVariableDataType::Uint32});
 
   WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
 
@@ -63,12 +62,13 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
   int n_reps_;
   const Tensor* seqlen_k_;
   bool past_present_share_buffer_;
+  bool is_first_prompt_;
 };
 
 class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
  public:
-  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
-      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
+  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
+      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -78,19 +78,19 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
                                            {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
-                                           {"elements_per_thread", ProgramUniformVariableDataType::Uint32},
-                                           {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
+                                           {"elements_per_thread", ProgramUniformVariableDataType::Uint32});
 
  private:
   int work_group_size_;
   int components_;
   const Tensor* seqlen_k_;
+  bool is_first_prompt_;
 };
 
 class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
  public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -104,8 +104,7 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
                                            {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                            {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                           {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
+                                           {"n_reps", ProgramUniformVariableDataType::Uint32});
 
   WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
 
@@ -116,6 +115,7 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
   int n_reps_;
   const Tensor* seqlen_k_;
   bool past_present_share_buffer_;
+  bool is_first_prompt_;
 };
 
 }  // namespace webgpu
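Header-side, note that `is_first_prompt` is inserted ahead of the defaulted `n_reps`/`seqlen_k`/`past_present_share_buffer` parameters, so every construction site fails to compile until it passes the flag explicitly; appending it with a default value would instead let callers silently omit it. A minimal illustration of the pattern with a hypothetical type (not from the source):

```cpp
#include <string>

// Hypothetical stand-in for the Program subclasses above: placing the new
// mandatory parameter before the defaulted tail forces all call sites to be
// updated, which is the safer choice for a flag that changes shader codegen.
struct Example {
  Example(const std::string& name, int tile_size, bool is_first_prompt,
          int n_reps = 1, bool share_buffer = false)
      : name_(name), tile_size_(tile_size), is_first_prompt_(is_first_prompt),
        n_reps_(n_reps), share_buffer_(share_buffer) {}
  std::string name_;
  int tile_size_;
  bool is_first_prompt_;
  int n_reps_;
  bool share_buffer_;
};

int main() {
  Example e{"AttentionProbs", 12, /*is_first_prompt=*/true};  // defaults for the tail
  return e.n_reps_ == 1 ? 0 : 1;
}
```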