diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index dfe188a23f5a5..51dc53a41b2d7 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -569,44 +569,36 @@ fn debugKTile() -> q_value_t
   return sum_value;
 }
 
-fn loadk(slot: u32, k_idx_global : u32, head_idx: u32)
+fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
 {
-  if (k_idx_global >= uniforms.present_sequence_length) {
-    return;
-  }
-
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
   let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE;
-  for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++)
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
   {
     var value = present_key[idx+offset];
     k_tile[slot][idx] = value;
   }
 }
 
-fn loadv(slot: u32, v_idx_global : u32, head_idx: u32)
+fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
 {
-  if (v_idx_global >= uniforms.present_sequence_length) {
-    return;
-  }
-
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
   let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE;
-  for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx ++)
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
   {
     v_tile[slot][idx] = present_value[idx+offset];
   }
 }
 
-fn loadAttentionBias(qtile_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
+fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
 {
   // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
   if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) {
-    qk_tile[qtile_row][k_col] = 0.0;
+    qk_tile[q_row][k_col] = 0.0;
     return;
   }
   let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global;
-  qk_tile[qtile_row][k_col] = attention_bias[offset];
+  qk_tile[q_row][k_col] = attention_bias[offset];
 }
 
 fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
@@ -705,37 +697,54 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
 )HELPER_FN";
 
   // Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1)
-  // QKV_HEAD_VECTORIZED_SIZE % sg_id == 0 for loadq, loadk and computeDotProduct to work right.
-
   shader.MainFunctionBody() << R"MAIN_FN(
 let head_idx = workgroup_id.x;
 
-// Split the composite workgroup id into actual y and subgroup id.
-let q_tile_row = u32(local_idx / sg_size);
-let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row;
-// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query.
-loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
-max_tile[sg_id] = MIN_VALUE;
+let wave_x = (local_id.x / 4);
+let wave_y = (local_id.y / 4);
+let wave_id = wave_x + wave_y * 4;
+
+let q_idx_start = workgroup_id.y * TILE_SIZE;
+let q_idx_global = q_idx_start + wave_id;
+// Each wave (wave_id) gets a subgroup's worth of lanes and is responsible for 1 query.
+loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
+if (sg_id == 0)
+{
+  max_tile[wave_id] = MIN_VALUE;
+}
 for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE)
 {
-  if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) {
-    loadk(sg_id, k_start+sg_id, head_idx);
-    loadv(sg_id, k_start+sg_id, head_idx);
-    loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx);
+  let k_idx_global = k_start+wave_id;
+  let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length;
+  if (k_idx_global_using_wave_valid) {
+    // Leveraging the subgroup lanes for parallelism, load the K/V values from
+    // k_start+wave_id into slot wave_id.
+    loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size);
+    loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
+    // Next, for every q row (wave_id), populate the attention bias for key position
+    // (k_start+sg_id). loadAttentionBias handles range checking of q_idx_global
+    // and (k_start+sg_id).
+    loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
   }
   workgroupBarrier();
-  // Do k_idx + k_start <= q_idx_global if we want only look past.
-  for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++)
+  if (k_idx_global_using_wave_valid)
   {
-    computeDotProduct(q_tile_row, k_idx, sg_id, sg_size);
+    for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.present_sequence_length; q_idx++)
+    {
+      // Leveraging the subgroups for parallelism, compute the dot product of QK.
+      // For the case of new_seq 1 there is a single query over the full context length of K,
+      // so we iterate over q and use the waves for K so that this step can use all the
+      // waves in the workgroup.
+      computeDotProduct(q_idx, wave_id, sg_id, sg_size);
+    }
   }
-  let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
-  computeSoftMax(q_tile_row, sg_id, enabled);
-  computeO(q_tile_row, sg_id, enabled);
+  let k_idx_global_using_lane_valid:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
+  computeSoftMax(wave_id, sg_id, k_idx_global_using_lane_valid);
+  computeO(wave_id, sg_id, k_idx_global_using_lane_valid);
 }
 workgroupBarrier();
-writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
+writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
 )MAIN_FN";
 
@@ -747,33 +756,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                            AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
   ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length));
-  // // Uncomment to test CopyKVCache independent of FlashAttentionProgram.
-  // TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
-  //                               parameters.sequence_length, parameters.head_size});
-  // TensorShape q_new_shape(q_new_dims);
-  // Tensor Qn = context.CreateGPUTensor(Q->DataType(), q_new_shape);
-  // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
-  //     context, parameters.num_heads, parameters.sequence_length, parameters.head_size, Q, nullptr, 0, &Qn));
-
-  // TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
-  //                               parameters.kv_sequence_length, parameters.head_size});
-  // TensorShape k_new_shape(k_new_dims);
-  // Tensor Kn = context.CreateGPUTensor(K->DataType(), k_new_shape);
-  // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-  //                                       parameters.head_size, K, nullptr, parameters.hidden_size, &Kn));
-
-  // TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
-  //                               parameters.kv_sequence_length, parameters.v_head_size});
-  // TensorShape v_new_shape(v_new_dims);
-  // Tensor Vn = context.CreateGPUTensor(V->DataType(), v_new_shape);
-  // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-  //                                       parameters.v_head_size, V, nullptr, 2 * parameters.hidden_size, &Vn));
-
-  // return ApplyAttention(&Qn, &Kn, &Vn, attention_bias, past_key, past_value, output, present_key,
-  //                       present_value, parameters, context, true);
-
-  constexpr int subgroup_size = 8;
-  constexpr int tile_size = 8;
+  constexpr int subgroup_size = 16;
+  constexpr int tile_size = 16;
   bool has_attention_bias = attention_bias != nullptr;
   FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -789,7 +773,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                            std::to_string(parameters.head_size) +
                            std::to_string(parameters.num_heads);
   program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1)
-      .SetWorkgroupSize(subgroup_size*tile_size)
+      .SetWorkgroupSize(subgroup_size, subgroup_size)
      .CacheHint(cache_hint)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
                            {static_cast<uint32_t>(parameters.total_sequence_length)},
@@ -854,7 +838,6 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   if (parameters.batch_size == 1 &&
       bias == nullptr &&
-      past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 &&
       present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 &&
       parameters.head_size % 4 == 0) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
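
Illustrative sketch (editor's addition, not part of the patch): the new shader splits a 16x16 workgroup into 16 waves of 16 lanes via wave_id = (local_id.x / 4) + (local_id.y / 4) * 4, each wave owning one query row of the TILE_SIZE tile, while the host dispatches (num_heads, ceil(sequence_length / tile_size), 1) workgroups. The standalone C++ below mirrors that index math on the CPU; num_heads and sequence_length are hypothetical example values and nothing here is ONNX Runtime API.

// flash_attention_wave_layout_sketch.cc -- hypothetical, for illustration only.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t subgroup_size = 16;  // workgroup is subgroup_size x subgroup_size invocations
  constexpr uint32_t tile_size = 16;      // TILE_SIZE: one query row / one K,V slot per wave

  // Host-side dispatch math as in ApplyFlashAttention, with hypothetical sizes.
  const uint32_t num_heads = 32;
  const uint32_t sequence_length = 1;
  const uint32_t groups_y = (sequence_length + tile_size - 1) / tile_size;
  std::printf("DispatchGroupSize(%u, %u, 1)\n", num_heads, groups_y);

  // Emulate the per-invocation wave decomposition done at the top of MAIN_FN:
  //   wave_x = local_id.x / 4; wave_y = local_id.y / 4; wave_id = wave_x + wave_y * 4;
  // Sixteen invocations share each wave_id and cooperate on one query
  // (q_idx_global = workgroup_id.y * TILE_SIZE + wave_id).
  uint32_t lanes_per_wave[16] = {};
  for (uint32_t ly = 0; ly < subgroup_size; ++ly) {
    for (uint32_t lx = 0; lx < subgroup_size; ++lx) {
      const uint32_t wave_id = (lx / 4) + (ly / 4) * 4;  // 0..15
      ++lanes_per_wave[wave_id];
    }
  }
  for (uint32_t w = 0; w < 16; ++w) {
    std::printf("wave %2u: %u lanes\n", w, lanes_per_wave[w]);
  }
  return 0;
}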