
Improve comments and bounds checks
sushraja-msft committed Nov 23, 2024
1 parent e8bb833 commit 940f3b0
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -534,9 +534,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
{
- if (q_idx_global >= uniforms.new_sequence_length) {
-   return;
- }
// Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA
// This is the layout if TransferBSDToBNSH has not been run.
let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx;
@@ -573,7 +570,7 @@ fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
{
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
- if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) {
+ if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) {
qk_tile[q_row][k_col] = 0.0;
return;
}
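Note (not part of the committed change): the call site later in this diff passes the subgroup lane id as k_col, so the new k_col >= TILE_SIZE guard presumably covers devices whose subgroup size exceeds TILE_SIZE; a minimal sketch of the guarded call, using only names that appear in this diff:

  // Sketch only: sg_id is passed as k_col, so on a device where sg_size > TILE_SIZE
  // some lanes call this with k_col >= TILE_SIZE; the added condition bounds-checks
  // that case explicitly instead of relying on the sequence-length checks alone.
  loadAttentionBias(wave_id, q_idx_global, sg_id, k_start + sg_id, head_idx);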
@@ -583,9 +580,6 @@ fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global :
fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
{
- if (o_idx_global >= uniforms.new_sequence_length) {
-   return;
- }
// Stored as float16[batch_size,sequence_length,3072]
let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size)
@@ -598,6 +592,8 @@ fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u
fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32)
{
var sum:vec4<q_element_t> = q_value_t(0, 0, 0, 0);
+ // idx is not initialized to sg_id to ensure uniformity because the loop uses
+ // subgroupAdd and unused lanes need to be initialized with 0 for correctness.
for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size)
{
var result = q_value_t(0);
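Note (not part of the committed change): the comment added above is about WGSL uniformity requirements for subgroup operations; a minimal sketch of the pattern it describes, reusing names from this diff plus per_lane_value as a hypothetical stand-in for the real per-lane work:

  // Sketch only. idx starts at 0 for every lane, so all lanes execute the same
  // number of iterations and reach the subgroup built-in together.
  var sum = q_element_t(0);
  for (var idx : u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size)
  {
    var result = q_element_t(0);
    if (idx + sg_id < QKV_HEAD_VECTORIZED_SIZE)
    {
      result = per_lane_value(idx + sg_id);   // hypothetical per-lane work
    }
    // Out-of-range lanes keep result = 0, so they do not perturb the reduction.
    sum += subgroupAdd(result);
  }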
@@ -641,6 +637,7 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool)
var d = dleft + sum;
if (d == 0)
{
+ // Avoid division by zero by setting d to a really small value.
d = 0.0000001h;
}
qk_tile[q_idx][sg_id] = value / d;
@@ -682,12 +679,17 @@ let head_idx = workgroup_id.x;
let wave_x = (local_id.x / 4);
let wave_y = (local_id.y / 4);
+ // It is always the case that 0 <= wave_id < TILE_SIZE
let wave_id = wave_x + wave_y * 4;
let q_idx_start = workgroup_id.y * TILE_SIZE;
let q_idx_global = q_idx_start + wave_id;
- // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query.
- loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
+ let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length;
+ if (q_idx_global_using_wave_valid)
+ {
+ // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query.
+ loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
+ }
if (sg_id == 0)
{
max_tile[wave_id] = MIN_VALUE;
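Note (not part of the committed change): the new comment about 0 <= wave_id < TILE_SIZE follows from the wave arithmetic; a worked example, assuming a 16x16 workgroup and TILE_SIZE = 16 (assumptions, neither value is visible in this hunk):

  // local_id.x and local_id.y in [0, 16)  =>  wave_x = local_id.x / 4 in [0, 4)
  //                                           wave_y = local_id.y / 4 in [0, 4)
  // wave_id = wave_x + wave_y * 4          =>  wave_id in [0, 16), i.e. < TILE_SIZE
  // e.g. local_id = (13, 7): wave_x = 3, wave_y = 1, wave_id = 3 + 1 * 4 = 7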
@@ -704,7 +706,7 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_
loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
// Next, we want for every q row (wave_id) to populate bias for new sequence length
// (k_start+sg_id). loadAttentionBias handles range checking q_idx_global,
- // and (k_start+sg_id).
+ // and sg_id, (k_start+sg_id).
loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
}
workgroupBarrier();
@@ -716,16 +718,20 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_
// Because for the case of new_seq 1, there is a single query and context length of K
// we iterate over q and use the waves for K so that this step can use all the waves
// in the workgroup.
+ // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to
+ // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE.
computeDotProduct(q_idx, wave_id, sg_id, sg_size);
}
}
- let k_idx_global_using_lane_valid:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
- computeSoftMax(wave_id, sg_id, k_idx_global_using_lane_valid);
- computeO(wave_id, sg_id, k_idx_global_using_lane_valid);
+ let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
+ computeSoftMax(wave_id, sg_id, wave_lane_valid);
+ computeO(wave_id, sg_id, wave_lane_valid);
}
workgroupBarrier();
- writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
+ if (q_idx_global_using_wave_valid)
+ {
+ writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
+ }
)MAIN_FN";

return Status::OK();
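Note (not part of the committed change): taken together, the shader changes hoist the per-row bounds check out of loadq and writeo and into the caller, which reuses the same flag in the softmax/output validity mask; a condensed sketch of the resulting control flow, using only names that appear in this diff:

  let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length;
  if (q_idx_global_using_wave_valid)
  {
    loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
  }
  // ... per K tile: load K, V, attention bias, compute QK, then ...
  let wave_lane_valid : bool = q_idx_global_using_wave_valid &&
      sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
  computeSoftMax(wave_id, sg_id, wave_lane_valid);
  computeO(wave_id, sg_id, wave_lane_valid);
  // ... after the K loop ...
  if (q_idx_global_using_wave_valid)
  {
    writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
  }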
