Skip to content

Commit eb453df

Browse files
authored
Move flash decoding shaders into templates (microsoft#25774)
### Description
Put the flash decoding shaders into three template files.

### Motivation and Context
Moving to templates will improve code readability.
1 parent 38ef3ad commit eb453df

File tree

4 files changed

+206
-230
lines changed

4 files changed

+206
-230
lines changed

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 15 additions & 230 deletions
Original file line numberDiff line numberDiff line change
@@ -151,103 +151,14 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
151151
}
152152
shader.AddOutput("output", ShaderUsage::UseUniform);
153153
shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
154-
// Note that this shader adopts similar algorithm with dp4a generation shader.
155-
//
156-
// This algorithm works to compute the dot product of keys with queries in parallel, by processing on the k (head_size) dimension at each step amongst tile_size_k_vec threads,
157-
// and utilizing the remaining threads in the workgroup to process additional rows of |present_key| in parallel (such that the values in shared memory (tile_q) for |q| can be reused).
158-
// For each load of q, the tile_size_k_vec threads also reload |present_key| tile_size/sub_tile_count times to compute partial dot products of other |present_key| rows
159-
// in order to complete all tile_size |present_key| rows in this workgroup and also reusing the loaded in register values of |q|.
160-
constexpr int tile_size_k_vec = 8;
161-
162-
// 1. Each workgroup processes one row of |q| and tile_size rows of |present_key|
163-
//
164-
// 2. Computation Process:
165-
// - Reads [tile_size][tile_size_k_vec] block of |present_key| data at a time
166-
// - Each thread within workgroup computes dot products of 4 A*B elements since each k represents 4 elements of |present_key|
167-
// - Stores intermediate results in shared memory (inner_qk_values)
168-
// - Iterates through columns (head_size_vec) accumulating results in inner_qk_values
169-
// - Performs final reduction sum in inner_qk_values for output
170-
shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n"
171-
<< "const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
172-
<< "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n";
173-
shader.AdditionalImplementation() << R"ADDNL_FN(
174-
var<workgroup> tile_q: array<q_value_t, tile_size_k_vec>;
175-
var<workgroup> inner_qk_values: array<array<q_element_t, tile_size_k_vec>, tile_size>;
176-
var<workgroup> tile_qk: array<q_element_t, tile_size>;
177-
)ADDNL_FN";
178-
179-
if (has_attention_bias_) {
180-
shader.AdditionalImplementation() << R"HELPER_FN(
181-
fn loadAttentionBias(idx: u32) -> q_element_t
182-
{
183-
return attention_bias[idx];
184-
}
185-
)HELPER_FN";
186-
} else {
187-
shader.AdditionalImplementation() << R"HELPER_FN(
188-
fn loadAttentionBias(idx: u32) -> q_element_t
189-
{
190-
return q_element_t(0);
191-
}
192-
)HELPER_FN";
193-
}
194-
195-
shader.MainFunctionBody() << R"MAIN_FN(
196-
let local_row = u32(local_idx / tile_size_k_vec);
197-
let local_col = local_idx % tile_size_k_vec;
198-
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
199-
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
200-
let q_offset = head_idx * uniforms.head_size_vec;
201-
var total_sequence_length = uniforms.total_sequence_length;
202-
let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
203-
for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
204-
if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
205-
tile_q[local_idx] = q[q_offset + k + local_idx];
206-
}
207-
workgroupBarrier();
208-
let q_data = tile_q[local_col] * q_element_t(uniforms.alpha);
209-
if (k + local_col < uniforms.head_size_vec) {
210-
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
211-
if (total_seq_offset + row_offset + local_row < total_sequence_length) {
212-
inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data);
213-
}
214-
}
215-
}
216-
workgroupBarrier();
217-
}
218154

219-
if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length && head_idx < uniforms.num_heads) {
220-
var sum = q_element_t(0);
221-
for (var i = 0u; i < tile_size_k_vec; i++) {
222-
sum += inner_qk_values[local_idx][i];
223-
}
224-
225-
sum = sum + loadAttentionBias(head_idx * total_sequence_length + total_seq_offset + local_idx);
226-
tile_qk[local_idx] = sum;
227-
output[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
228-
}
229-
workgroupBarrier();
230-
231-
if (head_idx >= uniforms.num_heads) {
232-
return;
233-
}
234-
235-
if (local_idx == 0u) {
236-
// Calculate the max and sum in current split.
237-
var l_max = f32(-3.402823e+38f);
238-
var l_sum = f32(0);
239-
for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
240-
l_max = max(l_max, f32(tile_qk[i]));
241-
}
242-
for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
243-
l_sum += exp(f32(tile_qk[i]) - l_max);
244-
}
245-
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile;
246-
metadata[meta_offset] = metadata_value_t(l_max, l_sum);
247-
}
248-
)MAIN_FN";
249-
250-
return Status::OK();
155+
const uint32_t tile_size_k_vec = 8;
156+
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
157+
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template",
158+
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
159+
WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
160+
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
161+
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
251162
}
252163

253164
Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
@@ -291,96 +202,13 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
291202
shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
292203
shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
293204

294-
// Note that this shader adopts similar algorithm with dp4a generation shader.
295-
//
296-
// This algorithm works to compute the dot product of v with qk in parallel, by processing on the head_size dimension at each step amongst tile_size_k_vec threads,
297-
// and utilizing the remaining threads in the workgroup to process additional rows of |present_value| in parallel (such that the values in shared memory (tile_qk) for |qk| can be reused).
298-
// The tile_size_k_vec threads also reload |present_value| tile_size/sub_tile_count times to compute partial dot products of other |present_value| rows
299-
// in order to complete all tile_size |present_value| rows in this workgroup and also reusing the values in tile_qk.
300-
//
301-
// The difference with FlashAttentionDecodeQKTProgram is that the dot products go through the rows (total_sequence_length) of |present_value| instead of columns (head_size_vec).
302-
// And each workgroup only calculate current tile_size's dot products instead of iterating the whole row |total_sequence_length|.
303-
// That's why this shader is a split shader. The final reduce will be done in FlashAttentionDecodeReduceProgram.
304-
constexpr int tile_size_k_vec = 8;
305-
306-
shader.AdditionalImplementation() << "const head_size_vec = " << head_size_vec_ << "u;\n"
307-
<< "const tile_size = " << tile_size_ << "u;\n"
308-
<< "const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
309-
<< "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n";
310-
shader.AdditionalImplementation() << R"HELPER_FN(
311-
var<workgroup> tile_qk: array<present_value_element_t, tile_size>;
312-
var<workgroup> tile_output: array<present_value_value_t, head_size_vec>;
313-
var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>, sub_tile_count>;
314-
315-
)HELPER_FN";
316-
317-
// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx and FlashAttentionDecodeVxReduce, which can also reduce the intermediate memory.
318-
// The FlashAttentionDecodeQKT can be merged into split shader and do the final softmax adjustment in the reduce shader. However, some issues are met that when
319-
// the total sequence length exceeds some value, the result will become garbage. Since it can't be resolved in a short time, leave it as TODO to fix it in future.
320-
shader.MainFunctionBody() << R"MAIN_FN(
321-
let local_row = u32(local_idx / tile_size_k_vec);
322-
let local_col = local_idx % tile_size_k_vec;
323-
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
324-
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
325-
var total_sequence_length = uniforms.total_sequence_length;
326-
let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.head_size_vec * uniforms.present_sequence_length;
327-
328-
// Calculate the global max and sum in qk.
329-
if (head_idx < uniforms.num_heads)
330-
{
331-
var g_max = f32(-3.402823e+38f);
332-
var g_sum = f32(0);
333-
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
334-
{
335-
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
336-
g_max = max(g_max, metadata[meta_offset].x);
337-
}
338-
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
339-
{
340-
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
341-
let m_value = metadata[meta_offset];
342-
g_sum += exp(m_value.x - g_max) * m_value.y;
343-
}
344-
345-
if (total_seq_offset + local_idx < total_sequence_length) {
346-
tile_qk[local_idx] = present_value_element_t(exp(f32(qk[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
347-
}
348-
}
349-
for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
350-
var value = present_value_value_t(0);
351-
qkv_values[local_row][local_col] = present_value_value_t(0);
352-
workgroupBarrier();
353-
354-
if (k + local_col < uniforms.head_size_vec) {
355-
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
356-
if (total_seq_offset + row_offset + local_row < total_sequence_length) {
357-
value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col] * tile_qk[row_offset + local_row];
358-
}
359-
}
360-
}
361-
362-
qkv_values[local_row][local_col] = value;
363-
workgroupBarrier();
364-
365-
if (local_idx < tile_size_k_vec) {
366-
for (var i = 0u; i < sub_tile_count; i++) {
367-
tile_output[k + local_idx] += qkv_values[i][local_idx];
368-
}
369-
}
370-
workgroupBarrier();
371-
}
205+
const uint32_t tile_size_k_vec = 8u;
372206

373-
if (head_idx >= uniforms.num_heads) {
374-
return;
375-
}
376-
377-
for (var i = local_idx; i < uniforms.head_size_vec; i += workgroup_size_x) {
378-
let out_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * uniforms.head_size_vec + i;
379-
out_split_vx[out_offset] = tile_output[i];
380-
}
381-
)MAIN_FN";
382-
383-
return Status::OK();
207+
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template",
208+
WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
209+
WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
210+
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
211+
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
384212
}
385213

386214
Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
@@ -417,51 +245,8 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha
417245
shader.AddInput("input", ShaderUsage::UseUniform);
418246
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
419247

420-
// Inputs are splits of the GQA output, split into num_total_seq_length_tiles rows.
421-
// This shader needs to add these splits across the row dimension to arrive at the final result. The column is head size wide.
422-
// The reduction achieves maximum parallelization by splitting this task first into tile_size columns that each workgroup is responsible for.
423-
// Then within each workgroup the task of summation over the num_total_seq_length_tile for the tile_size columns is further split in two ways.
424-
// First across the row dimension to have WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE rows.
425-
// Then across the column dimension where each thread is responsible for 1 column of the TILE_SIZE columns the workgroup is responsible for.
426-
shader.AdditionalImplementation() << "const TILE_SIZE = " << tile_size_ << ";\n";
427-
shader.AdditionalImplementation() << R"HELPER_FN(
428-
var<workgroup> tile_input: array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
429-
)HELPER_FN";
430-
431-
shader.MainFunctionBody() << R"MAIN_FN(
432-
let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE;
433-
let head_idx = u32(workgroup_idx / uniforms.num_head_size_tile);
434-
let in_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
435-
var value = output_value_t(0);
436-
let local_row = u32(local_idx / TILE_SIZE);
437-
let local_col = local_idx % TILE_SIZE;
438-
439-
if (head_size_offset + local_col < uniforms.head_size_vec) {
440-
for (var r = 0u; r < uniforms.num_total_seq_length_tile; r += TILE_SIZE) {
441-
if (r + local_row < uniforms.num_total_seq_length_tile) {
442-
value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col];
443-
}
444-
}
445-
}
446-
447-
tile_input[local_row][local_col] = value;
448-
workgroupBarrier();
449-
450-
if (head_idx >= uniforms.num_heads) {
451-
return;
452-
}
453-
454-
if (local_idx < TILE_SIZE && head_size_offset + local_idx < uniforms.head_size_vec) {
455-
value = output_value_t(0);
456-
for (var i = 0u; i < TILE_SIZE; i++) {
457-
value += tile_input[i][local_idx];
458-
}
459-
let output_id = head_idx * uniforms.head_size_vec + head_size_offset + local_idx;
460-
output[output_id] = value;
461-
}
462-
)MAIN_FN";
463-
464-
return Status::OK();
248+
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
249+
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_));
465250
}
466251

467252
Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#param has_attention_bias
5+
#param tile_size
6+
#param tile_size_k_vec
7+
#param sub_tile_count
8+
9+
var<workgroup> tile_q: array<q_value_t, tile_size_k_vec>;
10+
var<workgroup> inner_qk_values: array<array<q_element_t, tile_size_k_vec>, tile_size>;
11+
var<workgroup> tile_qk: array<q_element_t, tile_size>;
12+
13+
#if has_attention_bias
14+
fn loadAttentionBias(idx: u32) -> q_element_t
15+
{
16+
return attention_bias[idx];
17+
}
18+
#else
19+
fn loadAttentionBias(idx: u32) -> q_element_t
20+
{
21+
return q_element_t(0);
22+
}
23+
#endif
24+
25+
$MAIN {
26+
let local_row = u32(local_idx / tile_size_k_vec);
27+
let local_col = local_idx % tile_size_k_vec;
28+
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
29+
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
30+
let q_offset = head_idx * uniforms.head_size_vec;
31+
var total_sequence_length = uniforms.total_sequence_length;
32+
let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
33+
for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
34+
if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
35+
tile_q[local_idx] = q[q_offset + k + local_idx];
36+
}
37+
workgroupBarrier();
38+
let q_data = tile_q[local_col] * q_element_t(uniforms.alpha);
39+
if (k + local_col < uniforms.head_size_vec) {
40+
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
41+
if (total_seq_offset + row_offset + local_row < total_sequence_length) {
42+
inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data);
43+
}
44+
}
45+
}
46+
workgroupBarrier();
47+
}
48+
49+
if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length && head_idx < uniforms.num_heads) {
50+
var sum = q_element_t(0);
51+
for (var i = 0u; i < tile_size_k_vec; i++) {
52+
sum += inner_qk_values[local_idx][i];
53+
}
54+
55+
sum = sum + loadAttentionBias(head_idx * total_sequence_length + total_seq_offset + local_idx);
56+
tile_qk[local_idx] = sum;
57+
output[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
58+
}
59+
workgroupBarrier();
60+
61+
if (head_idx >= uniforms.num_heads) {
62+
return;
63+
}
64+
65+
if (local_idx == 0u) {
66+
// Calculate the max and sum in current split.
67+
var l_max = f32(-3.402823e+38f);
68+
var l_sum = f32(0);
69+
for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
70+
l_max = max(l_max, f32(tile_qk[i]));
71+
}
72+
for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
73+
l_sum += exp(f32(tile_qk[i]) - l_max);
74+
}
75+
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile;
76+
metadata[meta_offset] = metadata_value_t(l_max, l_sum);
77+
}
78+
}

0 commit comments

Comments
 (0)