diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index 4934845ce1cfd..0af6d2d4e907a 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -673,14 +673,15 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
 )HELPER_FN";
 
-  // Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1)
+  // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1)
+  // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's.
   shader.MainFunctionBody() << R"MAIN_FN(
 let head_idx = workgroup_id.x;
-let wave_x = (local_id.x / 4);
-let wave_y = (local_id.y / 4);
+let wave_x = u32(local_id.x / 4);
+let wave_y = u32(local_id.y / 4);
 // It is always the case that 0 <= wave_id < TILE_SIZE
-let wave_id = wave_x + wave_y * 4;
+let wave_id:u32 = wave_x + wave_y * 4;
 let q_idx_start = workgroup_id.y * TILE_SIZE;
 let q_idx_global = q_idx_start + wave_id;
@@ -712,7 +713,7 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_
   workgroupBarrier();
 
   if (k_idx_global_using_wave_valid) {
-    for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.present_sequence_length; q_idx++)
+    for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++)
     {
       // Leveraging the subgroups for parallelism, compute dot product of QK.
       // Because for the case of new_seq 1, there is a single query and context length of K
@@ -824,6 +825,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   if (parameters.batch_size == 1 &&
       bias == nullptr &&
+      past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 &&
       present_key != nullptr && present_value != nullptr &&
       present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 &&
       parameters.head_size % 4 == 0) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,