
Commit

Merge branch 'sajandhy/webgpu-ep-gqa-new' of https://github.com/microsoft/onnxruntime into sajandhy/webgpu-ep-gqa-new
satyajandhyala committed Nov 1, 2024
2 parents f1c7d17 + 2d21742 commit ce7a583
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -125,7 +125,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let kv_head_idx = head_idx / uniforms.n_reps;\n"
<< "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
<< "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
<< "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";;
<< "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";
if (feed_past_key_ && has_present_key_) {
shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n";
}
@@ -201,7 +201,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
const bool has_present_key = output_count > 1 && past_key;
const bool has_attention_bias = attention_bias != nullptr;
- const int tile_size = 12;
+ constexpr int tile_size = 12;
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
@@ -257,13 +257,13 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
<< "let sequence_length = uniforms.sequence_length;\n"
<< "var total_sequence_length = uniforms.total_sequence_length;\n";
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "total_sequence_length") << ";\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
@@ -295,7 +295,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " }\n"
<< "}\n";
if (seqlen_k_) {
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {\n"
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < total_sequence_length; total_seq_id++) {\n"
<< " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
<< "}\n";
}
@@ -319,7 +319,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
{total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
- .SetDispatchGroupSize(batch_size * num_heads * sequence_length)
+ .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
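Note for readers skimming the diff: the softmax program now dispatches a 3-D grid of (1, sequence_length, batch_size * num_heads) workgroups and carries the vectorized sequence length in a uniform renamed to total_sequence_length_comp, which the shader multiplies back by the component count. The snippet below is a minimal CPU-side sketch of that index math as it appears in the shader strings above; the concrete values and the even-division assumption are hypothetical and not taken from the commit.

// Hypothetical CPU-side illustration of the index math visible in the updated shader.
// Not ONNX Runtime code; names and sizes are assumptions for illustration only.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t batch_size = 2, num_heads = 8, sequence_length = 16;
  const uint32_t total_sequence_length = 96;  // assumed elements per attention row
  const uint32_t components = 4;              // vector width picked from head_size
  // The renamed uniform carries the vectorized length; assumed to divide evenly here.
  const uint32_t total_sequence_length_comp = total_sequence_length / components;

  // Dispatch is (x = 1, y = sequence_length, z = batch_size * num_heads), so each
  // workgroup recovers its batch/head from workgroup_id.z, as the shader does.
  for (uint32_t z = 0; z < batch_size * num_heads; ++z) {
    const uint32_t batch_idx = z / num_heads;  // "workgroup_id.z / uniforms.num_heads"
    const uint32_t head_idx = z % num_heads;   // "workgroup_id.z % uniforms.num_heads"
    (void)batch_idx;
    (void)head_idx;
  }
  // Inside the shader the scalar length is rebuilt as comp * components.
  std::printf("total_sequence_length = %u\n", total_sequence_length_comp * components);
  return 0;
}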
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -75,7 +75,7 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});

private:
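For context on what the renamed uniform feeds into: the shader strings in attention.cc compute an in-place causal softmax per row, taking a maximum over the causal prefix and zeroing positions past it (the trailing loop in the last hunk of that file). Below is a hedged scalar sketch of that pattern; the exp/sum portion lies outside the shown hunks, so this follows the standard max-shifted softmax rather than the exact kernel, and the per-thread partitioning, vectorization, and workgroup reductions of the real WGSL are omitted.

// Hypothetical scalar reference of the in-place causal softmax pattern sketched in
// the shader strings above. Purely illustrative; not the project's code.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

void causal_softmax_row(std::vector<float>& x, std::size_t seq_causal_length) {
  float row_max = -3.402823e+38f;  // same sentinel the shader initializes with
  for (std::size_t i = 0; i < seq_causal_length; ++i) row_max = std::max(row_max, x[i]);
  float sum = 0.0f;
  for (std::size_t i = 0; i < seq_causal_length; ++i) {
    x[i] = std::exp(x[i] - row_max);
    sum += x[i];
  }
  for (std::size_t i = 0; i < seq_causal_length; ++i) x[i] /= sum;
  // Positions past the causal length are cleared, mirroring the trailing loop in the diff.
  for (std::size_t i = seq_causal_length; i < x.size(); ++i) x[i] = 0.0f;
}

int main() {
  std::vector<float> row = {0.5f, 1.0f, -2.0f, 7.0f, 7.0f, 7.0f};  // last 3 get zeroed
  causal_softmax_row(row, 3);
  for (float v : row) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}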

0 comments on commit ce7a583
