You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
// Note that this shader adopts similar algorithm with dp4a generation shader.
155
-
//
156
-
// This algorithm works to compute dot product of keys with queries parallelly, by processing on the k (head_size) dimension at each step amongst tile_size_k_vec threads,
157
-
// and utilizing the remaining threads in the workgroup to process additional rows of |present_key| in parallel (such that the values in shared memory (tile_q) for |q| can be reused).
158
-
// For each load of q, the tile_size_k_vec threads also reload |present_key| tile_size/sub_tile_count times to compute partial dot products of other |present_key| rows
159
-
// in order to complete all tile_size |present_key| rows in this workgroup and also reusing the loaded in register values of |q|.
160
-
constexprint tile_size_k_vec = 8;
161
-
162
-
// 1. Each workgroup processes one row of |q| and tile_size rows of |present_key|
163
-
//
164
-
// 2. Computation Process:
165
-
// - Reads [tile_size][tile_size_k_vec] block of |present_key| data at a time
166
-
// - Each thread within workgroup computes dot products of 4 A*B elements since each k represents 4 elements of |present_key|
167
-
// - Stores intermediate results in shared memory (inner_qk_values)
168
-
// - Iterates through columns (head_size_vec) accumulating results in inner_qk_values
169
-
// - Performs final reduction sum in inner_qk_values for output
// Note that this shader adopts similar algorithm with dp4a generation shader.
295
-
//
296
-
// This algorithm works to compute dot product of v with qk parallelly, by processing on the head_size dimension at each step amongst tile_size_k_vec threads,
297
-
// and utilizing the remaining threads in the workgroup to process additional rows of |present_value| in parallel (such that the values in shared memory (tile_qk) for |qk| can be reused).
298
-
// The tile_size_k_vec threads also reload |present_value| tile_size/sub_tile_count times to compute partial dot products of other |present_value| rows
299
-
// in order to complete all tile_size |present_value| rows in this workgroup and also reusing the values in tile_qk.
300
-
//
301
-
// The difference with FlashAttentionDecodeQKTProgram is that the dot products go through the rows (total_sequence_length) of |present_value| instead of columns (head_size_vec).
302
-
// And each workgroup only calculate current tile_size's dot products instead of iterating the whole row |total_sequence_length|.
303
-
// That's why this shader is a split shader. The final reduce will be done in FlashAttentionDecodeReduceProgram.
// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx and FlashAttentionDecodeVxReduce, which can also reduce the intermediate memory.
318
-
// The FlashAttentionDecodeQKT can be merged into split shader and do the final softmax adjustment in the reduce shader. However, some issues are met that when
319
-
// the total sequence length exceeds some value, the result will become garbage. Since it can't be resolved in a short time, leave it as TODO to fix it in future.
320
-
shader.MainFunctionBody() << R"MAIN_FN(
321
-
let local_row = u32(local_idx / tile_size_k_vec);
322
-
let local_col = local_idx % tile_size_k_vec;
323
-
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
324
-
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
325
-
var total_sequence_length = uniforms.total_sequence_length;
326
-
let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.head_size_vec * uniforms.present_sequence_length;
327
-
328
-
// Calculate the global max and sum in qk.
329
-
if (head_idx < uniforms.num_heads)
330
-
{
331
-
var g_max = f32(-3.402823e+38f);
332
-
var g_sum = f32(0);
333
-
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
334
-
{
335
-
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
336
-
g_max = max(g_max, metadata[meta_offset].x);
337
-
}
338
-
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
339
-
{
340
-
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
341
-
let m_value = metadata[meta_offset];
342
-
g_sum += exp(m_value.x - g_max) * m_value.y;
343
-
}
344
-
345
-
if (total_seq_offset + local_idx < total_sequence_length) {
0 commit comments