diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index dfe188a23f5a5..00351913eeab1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -555,20 +555,6 @@ fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u3 } } -fn debugKTile() -> q_value_t -{ - var sum_value = q_value_t(0); - for (var qidx:u32 = 0; qidx < TILE_SIZE; qidx++) - { - for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++) - { - var value = k_tile[qidx][idx]; - sum_value += value; - } - } - return sum_value; -} - fn loadk(slot: u32, k_idx_global : u32, head_idx: u32) { if (k_idx_global >= uniforms.present_sequence_length) { @@ -709,33 +695,34 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; -// Split the composite workgroup id into actual y and subgroup id. -let q_tile_row = u32(local_idx / sg_size); +// Split the composite workgroup id into subgroup_cluster_id and subgroup_id. +let subgroup_cluster_id = u32(local_idx / sg_size); -let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row; -// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query. -loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +let q_idx_start = workgroup_id.y * TILE_SIZE; +let q_idx_global = q_idx_start + subgroup_cluster_id; +// Each invocation (subgroup_cluster_id) gets x threads (subgroup threads) and is responsible for 1 query. +loadq(subgroup_cluster_id, q_idx_global, head_idx, sg_id, sg_size); max_tile[sg_id] = MIN_VALUE; for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { - if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) { + let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + if (enabled) { loadk(sg_id, k_start+sg_id, head_idx); loadv(sg_id, k_start+sg_id, head_idx); - loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx); + loadAttentionBias(subgroup_cluster_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); // Do k_idx + k_start <= q_idx_global if we want only look past. for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) { - computeDotProduct(q_tile_row, k_idx, sg_id, sg_size); + computeDotProduct(subgroup_cluster_id, k_idx, sg_id, sg_size); } - let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; - computeSoftMax(q_tile_row, sg_id, enabled); - computeO(q_tile_row, sg_id, enabled); + computeSoftMax(subgroup_cluster_id, sg_id, enabled); + computeO(subgroup_cluster_id, sg_id, enabled); } workgroupBarrier(); -writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +writeo(subgroup_cluster_id, q_idx_global, head_idx, sg_id, sg_size); )MAIN_FN";