diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 5583f296fae42..0af6d2d4e907a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -403,6 +403,372 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T return Status::OK(); } +Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(total_sequence_length) + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_past_) { + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let kIdx = workgroup_id.x;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; + if (has_past_) { + shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" + << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * 
uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" + << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" + << " }\n" + << "}\n" + << "else if (kIdx >= uniforms.past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" + << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" + << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" + << " // Assumes kv have BNSH layout.\n" + << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" + << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + int past_sequence_length, int total_sequence_length) { + + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); + bool has_past = (past_sequence_length != 0); + CopyKVCacheProgram program{"CopyKVCache", components, has_past}; + if (has_past) { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } else { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } + + program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, + {present_value, ProgramTensorMetadataDependency::Rank, components}}); + + program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads) + .SetWorkgroupSize(1) + .CacheHint(std::to_string(components) + std::to_string(has_past)) + .AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)}, + {static_cast<uint32_t>(parameters.kv_sequence_length)}, + {static_cast<uint32_t>(parameters.head_size/ components)}}); + + return context.RunProgram(program); +} + +Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(new_sequence_length)(total_sequence_length) + // + // Expectation is that present_key and present_value contain past key and values since + // we are out of storage buffers a shader can have and both past/present can't be passed. + // The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads. 
+ constexpr int vectorization_size = 4; + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("present_key", ShaderUsage::UseUniform); + shader.AddInput("present_value", ShaderUsage::UseUniform); + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + // SUBGROUP_SIZE has to be the same as sg_size. For Intel this will be 8. + // TILE_SIZE is the number of groups sharing the k_tile. + // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when + // TILE_SIZE == SUBGROUP_SIZE. This is a separate constant from SUBGROUP_SIZE + // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu + // gpu limits. For Intel this TILE_SIZE will be 8. + shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" + << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" + << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" + << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" + << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" + << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" + << "const MIN_VALUE : q_element_t = -65504.0h;\n"; + + // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake + // GPU after which workgroups will be unscheduled to make space for memory. 
+ shader.AdditionalImplementation() << "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var<workgroup> qk_tile : array<array<q_element_t, TILE_SIZE>, TILE_SIZE>; // 8 * 2 * 8 = 128\n" + << "var<workgroup> max_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n" + << "var<workgroup> denom_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n" + << "var<workgroup> o_ratio : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n"; + + shader.AdditionalImplementation() << R"HELPER_FN( + +fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA + // This is the layout if TransferBSDToBNSH has not been run. + let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; + // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. + // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var value = q[idx+offset]; + q_tile[slot][idx] = value; + } +} + +fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) +{ + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) + { + var value = present_key[idx+offset]; + k_tile[slot][idx] = value; + } +} + +fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) +{ + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; 
idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) + { + v_tile[slot][idx] = present_value[idx+offset]; + } +} + +fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) +{ + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) { + qk_tile[q_row][k_col] = 0.0; + return; + } + let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; + qk_tile[q_row][k_col] = attention_bias[offset]; +} + +fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) + { + let value = o_tile[slot][idx]; + output[offset+idx] = value; + } +} + +fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) +{ + var sum:vec4<q_element_t> = q_value_t(0, 0, 0, 0); + // idx is not initialized to sg_id to ensure uniformity because the loop uses + // subgroupAdd and unused lanes need to be initialized with 0 for correctness. + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var result = q_value_t(0); + let sg_idx = idx+sg_id; + // If QKV_HEAD_VECTORIZED_SIZE is divisible by the subgroup size this if check is not + // required. Hopefully the compiler sees the first half of this if statement and + // removes this if instruction. 
+ if (QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 || sg_idx < QKV_HEAD_VECTORIZED_SIZE) + { + result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; + } + sum += subgroupAdd(result); + } + + if (sg_id == 0) + { + let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; + let sqrt_dk = q_element_t(uniforms.alpha); + let value = single_sum * sqrt_dk; + qk_tile[q_idx][k_idx] += value; + } +} + +fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) +{ + var x = MIN_VALUE; + if (enabled){ + x = qk_tile[q_idx][sg_id]; + } + var max_value = subgroupMax(x); + max_value = max(max_tile[q_idx], max_value); + let sub = x - max_value; + var value:q_element_t = 0; + if (enabled) { + value = exp(sub); + } + let sum = subgroupAdd(value); + + // Compute lhs term of update di prime and then compute di prime. + let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); + var d = dleft + sum; + if (d == 0) + { + // Avoid division by zero by setting d to a really small value. + d = 0.0000001h; + } + qk_tile[q_idx][sg_id] = value / d; + if (sg_id == 0) + { + max_tile[q_idx] = max_value; + denom_tile[q_idx] = d; + o_ratio[q_idx] = dleft / d; + } +} + +fn computeO(q_idx: u32, sg_id:u32, enabled:bool) +{ + var attn = q_element_t(0); + if (enabled) + { + attn = qk_tile[q_idx][sg_id]; + } + for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) + { + let val = v_tile[sg_id][i]; + var intermediate = attn * val; + let sum = subgroupAdd(intermediate); + if (sg_id == 0) + { + let o_ratio = o_ratio[q_idx]; + let old_o = o_tile[q_idx][i]; + let new_o = ( o_ratio * old_o) + sum; + o_tile[q_idx][i] = new_o; + } + } +} + +)HELPER_FN"; + +// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) +// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. 
+ shader.MainFunctionBody() << R"MAIN_FN( +let head_idx = workgroup_id.x; + +let wave_x = u32(local_id.x / 4); +let wave_y = u32(local_id.y / 4); +// It is always the case that 0 <= wave_id < TILE_SIZE +let wave_id:u32 = wave_x + wave_y * 4; + +let q_idx_start = workgroup_id.y * TILE_SIZE; +let q_idx_global = q_idx_start + wave_id; +let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; +if (q_idx_global_using_wave_valid) +{ + // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. + loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +if (sg_id == 0) +{ + max_tile[wave_id] = MIN_VALUE; +} + +for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) +{ + let k_idx_global = k_start+wave_id; + let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; + if (k_idx_global_using_wave_valid) { + // Leveraging the subgroup lanes for parallelism, load into slot wave_id + // K/V values from k_start+wave_id. + loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); + loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); + // Next, we want for every q row (wave_id) to populate bias for new sequence length + // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, + // and sg_id, (k_start+sg_id). + loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); + } + workgroupBarrier(); + if (k_idx_global_using_wave_valid) + { + for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) + { + // Leveraging the subgroups for parallelism, compute dot product of QK. + // Because for the case of new_seq 1, there is a single query and context length of K + // we iterate over q and use the waves for K so that this step can use all the waves + // in the workgroup. 
+ // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to + // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. + computeDotProduct(q_idx, wave_id, sg_id, sg_size); + } + } + let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + computeSoftMax(wave_id, sg_id, wave_lane_valid); + computeO(wave_id, sg_id, wave_lane_valid); +} +workgroupBarrier(); +if (q_idx_global_using_wave_valid) +{ + writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +)MAIN_FN"; + + return Status::OK(); +} + +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); + + constexpr int subgroup_size = 16; + constexpr int tile_size = 16; + bool has_attention_bias = attention_bias != nullptr; + FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + const float alpha = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast<float>(parameters.head_size)) + : parameters.scale; + std::string cache_hint = std::to_string(has_attention_bias) + + std::to_string(subgroup_size) + + std::to_string(tile_size) + + std::to_string(parameters.head_size) + + std::to_string(parameters.num_heads); + program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1) + .SetWorkgroupSize(subgroup_size, subgroup_size) + .CacheHint(cache_hint) + .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)}, + {static_cast<uint32_t>(parameters.total_sequence_length)}, + {alpha}}); + + return context.RunProgram(program); +} + MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : WebGpuKernel(info) { int64_t num_heads = 0; @@ -457,6 +823,15 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); + if (parameters.batch_size == 1 && + bias == nullptr && + past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && + present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && + present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context); + } + TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, parameters.sequence_length, parameters.head_size}); TensorShape q_new_shape(q_new_dims); diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 36803e3027b4c..147c31ac9445c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -98,6 +98,53 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> { int tile_size_; }; +class 
CopyKVCacheProgram final : public Program<CopyKVCacheProgram> { + public: + CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) + : Program{kernel_name}, components_(components), has_past_(has_past) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); + + private: + int components_; + bool has_past_; +}; + +class FlashAttentionProgram final : public Program<FlashAttentionProgram> { + public: + FlashAttentionProgram(const std::string& kernel_name, + bool has_attention_bias, + int subgroup_size, + int tile_size, + int qkv_head_size, + int qkv_num_heads) + : Program{kernel_name}, + has_attention_bias_(has_attention_bias), + subgroup_size_(subgroup_size), + tile_size_(tile_size), + qkv_head_size_(qkv_head_size), + qkv_num_heads_(qkv_num_heads) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}); + + private: + bool has_attention_bias_; + int subgroup_size_; + int tile_size_; + int qkv_head_size_; + int qkv_num_heads_; +}; + class MultiHeadAttention final : public WebGpuKernel { public: MultiHeadAttention(const OpKernelInfo& info); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 5685494556248..2215cc2a7cd20 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -62,7 +62,9 @@ Status ShaderHelper::Init() { "fn main(@builtin(global_invocation_id) global_id : vec3<u32>,\n" " @builtin(workgroup_id) workgroup_id : vec3<u32>,\n" " 
@builtin(local_invocation_index) local_idx : u32,\n" - " @builtin(local_invocation_id) local_id : vec3<u32>"; + " @builtin(local_invocation_id) local_id : vec3<u32>,\n" + " @builtin(subgroup_invocation_id) sg_id : u32,\n" + " @builtin(subgroup_size) sg_size : u32"; if (!is_1d_dispatch) { body_ss_ << ",\n" " @builtin(num_workgroups) num_workgroups : vec3<u32>";