More bug fixes
sushraja-msft committed Nov 23, 2024
1 parent 940f3b0 commit 45698d2
Showing 1 changed file with 7 additions and 5 deletions: onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -673,14 +673,15 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
)HELPER_FN";

-// Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1)
+// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1)
// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's.
shader.MainFunctionBody() << R"MAIN_FN(
let head_idx = workgroup_id.x;
-let wave_x = (local_id.x / 4);
-let wave_y = (local_id.y / 4);
+let wave_x = u32(local_id.x / 4);
+let wave_y = u32(local_id.y / 4);
// It is always the case that 0 <= wave_id < TILE_SIZE
-let wave_id = wave_x + wave_y * 4;
+let wave_id:u32 = wave_x + wave_y * 4;
let q_idx_start = workgroup_id.y * TILE_SIZE;
let q_idx_global = q_idx_start + wave_id;
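Note on the indexing above: dividing each local coordinate by 4 partitions the workgroup into a 4 x 4 grid of waves, so wave_id stays in [0, TILE_SIZE) as the comment asserts. A minimal WGSL sketch of the same layout (the 16x16 workgroup size and TILE_SIZE == 16 are assumptions for illustration; the visible hunks do not show them):

// Sketch only; workgroup size and TILE_SIZE are assumed, not taken from the diff.
const TILE_SIZE: u32 = 16u;

@compute @workgroup_size(16, 16, 1)
fn wave_layout_sketch(@builtin(local_invocation_id) local_id: vec3<u32>,
                      @builtin(workgroup_id) workgroup_id: vec3<u32>) {
  let wave_x = local_id.x / 4;           // 0..3
  let wave_y = local_id.y / 4;           // 0..3
  let wave_id = wave_x + wave_y * 4u;    // 0..15, one q row per wave
  let q_idx_global = workgroup_id.y * TILE_SIZE + wave_id;
  _ = q_idx_global;                      // phony assignment; value unused in the sketch
}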
@@ -712,7 +713,7 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE)
workgroupBarrier();
if (k_idx_global_using_wave_valid)
{
-for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.present_sequence_length; q_idx++)
+for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++)
{
// Leveraging the subgroups for parallelism, compute dot product of QK.
// Because for the case of new_seq 1, there is a single query and context length of K
@@ -824,6 +825,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&

if (parameters.batch_size == 1 &&
bias == nullptr &&
+past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 &&
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
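The "leveraging the subgroups" comment in the @@ -712 hunk refers to splitting one QK dot product across subgroup lanes and reducing with a subgroup sum. A hedged WGSL sketch of that pattern follows; q_tile, k_tile, HEAD_SIZE_VEC, and the staging into workgroup memory are illustrative assumptions, not the committed code:

enable subgroups;

const TILE_SIZE: u32 = 16u;
const HEAD_SIZE_VEC: u32 = 16u;  // head_size / 4; the gate above requires head_size % 4 == 0

// Hypothetical tiles of q and k staged in workgroup memory as vec4 slices.
var<workgroup> q_tile: array<array<vec4<f32>, HEAD_SIZE_VEC>, TILE_SIZE>;
var<workgroup> k_tile: array<array<vec4<f32>, HEAD_SIZE_VEC>, TILE_SIZE>;

// sg_id would come from @builtin(subgroup_invocation_id), sg_size from @builtin(subgroup_size).
fn qk_dot_sketch(q_idx: u32, k_idx: u32, sg_id: u32, sg_size: u32) -> f32 {
  var partial = 0.0;
  // Each lane accumulates a strided subset of the vec4 slices.
  for (var d = sg_id; d < HEAD_SIZE_VEC; d += sg_size) {
    partial += dot(q_tile[q_idx][d], k_tile[k_idx][d]);
  }
  // subgroupAdd sums the partials across lanes; every lane receives the total.
  return subgroupAdd(partial);
}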
