From 0a5d2129e22b2d397fd0f3eea095c00342f579f3 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Sun, 27 Oct 2024 10:47:35 -0700 Subject: [PATCH 01/43] Added attention_common.h --- .../webgpu/bert/attention_common.h | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/attention_common.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h new file mode 100644 index 0000000000000..119947d945de7 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +struct WebgpuAttentionParameters { + AttentionParameters(onnxruntime::contrib::AttentionParameters parameters): + is_gqa_parameters_(false), + batch_size_(parameters.parameters.batch_size), + sequence_length_(parameters.parameters.sequence_length), + kv_sequence_length_(parameters.parameters.kv_sequence_length), + past_sequence_length_(parameters.parameters.past_sequence_length), + total_sequence_length_(parameters.parameters.total_sequence_length), + max_sequence_length_(parameters.parameters.max_sequence_length), + input_hidden_size_(parameters.parameters.input_hidden_size), + hidden_size_(parameters.parameters.hidden_size), + head_size_(parameters.parameters.head_size), + v_hidden_size_(parameters.parameters.v_hidden_size), + v_head_size_(parameters.parameters.v_head_size), + num_heads_(parameters.parameters.num_heads), + is_unidirectional_(parameters.parameters.is_unidirectional), + past_present_share_buffer_(parameters.parameters.past_present_share_buffer), + do_rotary_(parameters.parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.parameters.mask_filter_value), + scale_(parameters.parameters.scale), + mask_type_(parameters.parameters.mask_type), + qkv_format_(parameters.parameters.qkv_format) { + } + + AttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters): + is_gqa_parameters_(true), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), + total_sequence_length_(parameters.total_sequence_length), + hidden_size_(parameters.hidden_size), + num_heads_(parameters.num_heads), + head_size_(parameters.head_size), + kv_hidden_size_(parameters.kv_hidden_size), + kv_num_heads_(parameters.kv_num_heads), + num_splits_(parameters.num_splits), + rotary_dim_(parameters.rotary_dim), + is_unidirectional_(parameters.is_unidirectional), + do_rotary_(parameters.do_rotary_), + rotary_interleaved_(parameters.rotary_interleaved_), + use_smooth_softmax_(parameters.use_smooth_softmax_), + mask_filter_value_(parameters.scale_), + softcap_(parameters.softcap_), + qkv_format_(parameters.qkv_format), + zeros_count_(parameters.zeros_count_), + zero_ptr_(parameters.zero_ptr_), + n_reps(parameters.num_heads / parameters.kv_num_heads) { 
+ } + + boolean is_gqa_parameters_; + int batch_size_(0) + int sequence_length_(0) + int kv_sequence_length_(0) // input sequence length of K or V + int past_sequence_length_(0) // sequence length in past state of K or V + int total_sequence_length_(0) // total sequence length of K or V + int max_sequence_length_(0) // max sequence length from 4D mask + int input_hidden_size_(0) // first dimension of weights for input projection + int hidden_size_(0) // hidden size of Q or K + int head_size_(0) // hidden size per head of Q or K + int v_hidden_size_(0) // hidden size of V + int v_head_size_(0) // hidden size per head of V + int num_heads_(0) + int rotary_embedding_(0) + bool is_unidirectional_(false) + bool past_present_share_buffer_(false) + bool do_rotary_(false) + bool broadcast_attn_bias_dim_0_(false) + bool broadcast_attn_bias_dim_1_(false) + float mask_filter_value_; + float scale_; + bool use_tf32_(false); + // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters + // and not in onnxruntime::contrib::AttentionParameters + int seqlen_past_kv_cache_(0) // sequence length of past kv tensor + int seqlen_present_kv_cache_(0) // sequence length of present kv tensor + int kv_hidden_size_(0) + int kv_num_heads_(0) + int num_splits_(0) // number of splits for splitkv + int rotary_dim_(0) // rotary embedding dimension + int local_window_size_(0) + bool kv_share_buffer_(false) + bool is_packed_qkv_(false) + bool is_subsequent_prompt_(false) // indicates whether we have past context and seqlen > 1 + bool is_first_prompt_(false) // indicates whether this is first decoding step + bool do_rotary_(false) + bool rotary_interleaved_(false) + bool use_smooth_softmax_(false) + float scale_; + float softcap_; + int zeros_count_(0); + int* zero_ptr_(nullptr); + // Computed values + int n_reps(1); + AttentionMaskType mask_type_; + AttentionQkvFormat qkv_format_; + }; + +} +} +} From 5bfa0705363d3fd19fc43b17b5e554c1c92be5cb Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Sun, 27 Oct 2024 22:28:23 -0700 Subject: [PATCH 02/43] wip --- .../webgpu/bert/attention_common.h | 145 +++++++++--------- .../webgpu/bert/multihead_attention.cc | 39 ++--- .../webgpu/bert/multihead_attention.h | 89 +---------- 3 files changed, 95 insertions(+), 178 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 119947d945de7..e5612100e00ba 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -9,106 +9,103 @@ #include "core/providers/webgpu/webgpu_kernel.h" #include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { namespace webgpu { struct WebgpuAttentionParameters { - AttentionParameters(onnxruntime::contrib::AttentionParameters parameters): + WebgpuAttentionParameters(AttentionParameters parameters): is_gqa_parameters_(false), - batch_size_(parameters.parameters.batch_size), - sequence_length_(parameters.parameters.sequence_length), - kv_sequence_length_(parameters.parameters.kv_sequence_length), - past_sequence_length_(parameters.parameters.past_sequence_length), - total_sequence_length_(parameters.parameters.total_sequence_length), - max_sequence_length_(parameters.parameters.max_sequence_length), - input_hidden_size_(parameters.parameters.input_hidden_size), - hidden_size_(parameters.parameters.hidden_size), - 
head_size_(parameters.parameters.head_size), - v_hidden_size_(parameters.parameters.v_hidden_size), - v_head_size_(parameters.parameters.v_head_size), - num_heads_(parameters.parameters.num_heads), - is_unidirectional_(parameters.parameters.is_unidirectional), - past_present_share_buffer_(parameters.parameters.past_present_share_buffer), - do_rotary_(parameters.parameters.do_rotary), - broadcast_attn_bias_dim_0_(parameters.parameters.broadcast_attn_bias_dim_0), - broadcast_attn_bias_dim_1_(parameters.parameters.broadcast_attn_bias_dim_1), - mask_filter_value_(parameters.parameters.mask_filter_value), - scale_(parameters.parameters.scale), - mask_type_(parameters.parameters.mask_type), - qkv_format_(parameters.parameters.qkv_format) { + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.kv_sequence_length), + past_sequence_length_(parameters.past_sequence_length), + total_sequence_length_(parameters.total_sequence_length), + max_sequence_length_(parameters.max_sequence_length), + input_hidden_size_(parameters.input_hidden_size), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.v_hidden_size), + v_head_size_(parameters.v_head_size), + num_heads_(parameters.num_heads), + is_unidirectional_(parameters.is_unidirectional), + past_present_share_buffer_(parameters.past_present_share_buffer), + do_rotary_(parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.mask_filter_value), + scale_(parameters.scale), + mask_type_(parameters.mask_type), + qkv_format_(parameters.qkv_format) { } - AttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters): + WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters): is_gqa_parameters_(true), batch_size_(parameters.batch_size), sequence_length_(parameters.sequence_length), - seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), - seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), total_sequence_length_(parameters.total_sequence_length), hidden_size_(parameters.hidden_size), - num_heads_(parameters.num_heads), head_size_(parameters.head_size), + num_heads_(parameters.num_heads), + do_rotary_(parameters.do_rotary), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), kv_hidden_size_(parameters.kv_hidden_size), kv_num_heads_(parameters.kv_num_heads), num_splits_(parameters.num_splits), rotary_dim_(parameters.rotary_dim), - is_unidirectional_(parameters.is_unidirectional), - do_rotary_(parameters.do_rotary_), - rotary_interleaved_(parameters.rotary_interleaved_), - use_smooth_softmax_(parameters.use_smooth_softmax_), - mask_filter_value_(parameters.scale_), - softcap_(parameters.softcap_), - qkv_format_(parameters.qkv_format), - zeros_count_(parameters.zeros_count_), - zero_ptr_(parameters.zero_ptr_), - n_reps(parameters.num_heads / parameters.kv_num_heads) { + rotary_interleaved_(parameters.rotary_interleaved), + use_smooth_softmax_(parameters.use_smooth_softmax), + softcap_(parameters.softcap), + zeros_count_(parameters.zeros_count), + zero_ptr_(parameters.zero_ptr), + n_reps(parameters.num_heads / parameters.kv_num_heads), + qkv_format_(parameters.qkv_format) { } - boolean is_gqa_parameters_; - int batch_size_(0) - int sequence_length_(0) - int kv_sequence_length_(0) // input 
sequence length of K or V - int past_sequence_length_(0) // sequence length in past state of K or V - int total_sequence_length_(0) // total sequence length of K or V - int max_sequence_length_(0) // max sequence length from 4D mask - int input_hidden_size_(0) // first dimension of weights for input projection - int hidden_size_(0) // hidden size of Q or K - int head_size_(0) // hidden size per head of Q or K - int v_hidden_size_(0) // hidden size of V - int v_head_size_(0) // hidden size per head of V - int num_heads_(0) - int rotary_embedding_(0) - bool is_unidirectional_(false) - bool past_present_share_buffer_(false) - bool do_rotary_(false) - bool broadcast_attn_bias_dim_0_(false) - bool broadcast_attn_bias_dim_1_(false) + bool is_gqa_parameters_; + int batch_size_ = 0; + int sequence_length_ = 0; + int kv_sequence_length_ = 0; // input sequence length of K or V + int past_sequence_length_ = 0; // sequence length in past state of K or V + int total_sequence_length_ = 0; // total sequence length of K or V + int max_sequence_length_ = 0; // max sequence length from 4D mask + int input_hidden_size_ = 0; // first dimension of weights for input projection + int hidden_size_ = 0; // hidden size of Q or K + int head_size_ = 0; // hidden size per head of Q or K + int v_hidden_size_ = 0; // hidden size of V + int v_head_size_ = 0; // hidden size per head of V + int num_heads_ = 0; + int rotary_embedding_ = 0; + bool is_unidirectional_ = false; + bool past_present_share_buffer_ = false; + bool do_rotary_ = false; + bool broadcast_attn_bias_dim_0_ = false; + bool broadcast_attn_bias_dim_1_ = false; float mask_filter_value_; float scale_; - bool use_tf32_(false); + bool use_tf32_ = false;; // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters // and not in onnxruntime::contrib::AttentionParameters - int seqlen_past_kv_cache_(0) // sequence length of past kv tensor - int seqlen_present_kv_cache_(0) // sequence length of present kv tensor - int kv_hidden_size_(0) - int kv_num_heads_(0) - int num_splits_(0) // number of splits for splitkv - int rotary_dim_(0) // rotary embedding dimension - int local_window_size_(0) - bool kv_share_buffer_(false) - bool is_packed_qkv_(false) - bool is_subsequent_prompt_(false) // indicates whether we have past context and seqlen > 1 - bool is_first_prompt_(false) // indicates whether this is first decoding step - bool do_rotary_(false) - bool rotary_interleaved_(false) - bool use_smooth_softmax_(false) - float scale_; + int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor + int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor + int kv_hidden_size_ = 0; + int kv_num_heads_ = 0; + int num_splits_ = 0; // number of splits for splitkv + int rotary_dim_ = 0; // rotary embedding dimension + int local_window_size_ = 0; + bool kv_share_buffer_ = false; + bool is_packed_qkv_ = false; + bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt_ = false; // indicates whether this is first decoding step + bool rotary_interleaved_ = false; + bool use_smooth_softmax_ = false; float softcap_; - int zeros_count_(0); - int* zero_ptr_(nullptr); + int zeros_count_ = 0;; + int* zero_ptr_ = nullptr; // Computed values - int n_reps(1); + int n_reps = 1; AttentionMaskType mask_type_; AttentionQkvFormat qkv_format_; }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 
5583f296fae42..4a284c7989b2b 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/webgpu/bert/attention_common.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

@@ -441,47 +442,47 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
                                             context.DeviceLimits().maxComputeInvocationsPerWorkgroup));

   TensorShapeVector output_shape(3);
-  output_shape[0] = static_cast<int64_t>(parameters.batch_size);
-  output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
-  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
+  output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
+  output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
+  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size_);
   Tensor* output = context.Output(0, output_shape);

   // If optional outputs aren't needed, present_key and present_value will be null
   std::vector<int64_t> present_dims{
-      parameters.batch_size,
-      parameters.num_heads,
-      parameters.total_sequence_length,
-      parameters.head_size,
+      parameters.batch_size_,
+      parameters.num_heads_,
+      parameters.total_sequence_length_,
+      parameters.head_size_,
   };
   TensorShape present_shape(present_dims);
   Tensor* present_key = context.Output(1, present_shape);
   Tensor* present_value = context.Output(2, present_shape);

-  TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
-                                parameters.sequence_length, parameters.head_size});
+  TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.sequence_length_, parameters.head_size_});
   TensorShape q_new_shape(q_new_dims);
   Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
-      context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q));
+      context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q));

-  if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
+  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
     return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
                           present_value, parameters, context);
   }

-  TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
-                                parameters.kv_sequence_length, parameters.head_size});
+  TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.kv_sequence_length_, parameters.head_size_});
   TensorShape k_new_shape(k_new_dims);
   Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
-  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-                                        parameters.head_size, key, bias, parameters.hidden_size, &K));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
+                                        parameters.head_size_, key, bias, parameters.hidden_size_, &K));

-  TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
-                                parameters.kv_sequence_length, parameters.v_head_size});
+  TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.kv_sequence_length_, parameters.v_head_size_});
   TensorShape v_new_shape(v_new_dims);
   Tensor V = context.CreateGPUTensor(value->DataType(),
v_new_shape);
-  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-                                        parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
+                                        parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V));

   // Compute the attention score and apply the score to V
   return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
index 36803e3027b4c..fd1c37356df23 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
@@ -7,6 +7,9 @@
 #include "core/providers/webgpu/program.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_kernel.h"
+#include "contrib_ops/webgpu/bert/attention.h"
+
+#include "contrib_ops/cpu/bert/attention_base.h"

 namespace onnxruntime {
 namespace contrib {
@@ -14,91 +17,7 @@ namespace webgpu {

 using namespace onnxruntime::webgpu;

-class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
- public:
-  TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}
-
-  Status GenerateShaderCode(ShaderHelper& sh) const override;
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
-                                          {"batch_offset", ProgramUniformVariableDataType::Uint32},
-                                          {"sequence_offset", ProgramUniformVariableDataType::Uint32},
-                                          {"head_offset", ProgramUniformVariableDataType::Uint32},
-                                          {"bias_offset", ProgramUniformVariableDataType::Uint32});
-
- private:
-  bool has_bias_;
-};
-
-class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
- public:
-  AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) {
-  }
-
-  Status GenerateShaderCode(ShaderHelper& sh) const override;
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
-                                          {"K", ProgramUniformVariableDataType::Uint32},
-                                          {"N", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
-                                          {"alpha", ProgramUniformVariableDataType::Float32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32});
-
-  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
-
- private:
-  bool feed_past_key_;
-  bool has_present_key_;
-  bool has_attention_bias_;
-  int tile_size_;
-  int components_;
-};
-
-class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
- public:
-  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components)
-      : Program{kernel_name}, work_group_size_(work_group_size), components_(components) {
-  }
-
-  Status GenerateShaderCode(ShaderHelper& sh) const override;
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32},
-                                          {"d_comp", ProgramUniformVariableDataType::Uint32},
-                                          {"elements_per_thread", ProgramUniformVariableDataType::Uint32});
-
- private:
-  int work_group_size_;
-  int components_;
-};
-
-class 
VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
- public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) {
-  }
-
-  Status GenerateShaderCode(ShaderHelper& sh) const override;
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
-                                          {"K", ProgramUniformVariableDataType::Uint32},
-                                          {"N", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
-                                          {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32});
-
-  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
-
- private:
-  bool feed_past_value_;
-  bool has_present_value_;
-  int tile_size_;
-};
-
-class MultiHeadAttention final : public WebGpuKernel {
+class MultiHeadAttention final : public WebGpuKernel, public AttentionBase {
  public:
   MultiHeadAttention(const OpKernelInfo& info);
   Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;

From e6615e9e949f6bf67d15bb6669e94f81c9a78430 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Mon, 28 Oct 2024 11:40:44 -0700
Subject: [PATCH 03/43] Fix compilation errors

---
 onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index e5612100e00ba..7cdd222d132f7 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -101,13 +101,13 @@ struct WebgpuAttentionParameters {
   bool is_first_prompt_ = false;  // indicates whether this is first decoding step
   bool rotary_interleaved_ = false;
   bool use_smooth_softmax_ = false;
-  float softcap_;
+  float softcap_ = 0.0;
   int zeros_count_ = 0;;
   int* zero_ptr_ = nullptr;
   // Computed values
   int n_reps = 1;
-  AttentionMaskType mask_type_;
-  AttentionQkvFormat qkv_format_;
+  AttentionMaskType mask_type_ = MASK_NONE;
+  AttentionQkvFormat qkv_format_ = UNKNOWN;
   };

 }
From 449afb4d6db20d9fcac7bd2e073717e3c502f07d Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Mon, 28 Oct 2024 15:05:26 -0700
Subject: [PATCH 04/43] lint

---
 .../webgpu/bert/attention_common.h            | 112 +++++++++---------
 1 file changed, 56 insertions(+), 56 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index 7cdd222d132f7..0ecb38cdc83fd 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -15,57 +15,55 @@ namespace contrib {
 namespace webgpu {

 struct WebgpuAttentionParameters {
-  WebgpuAttentionParameters(AttentionParameters parameters):
-    is_gqa_parameters_(false),
-    batch_size_(parameters.batch_size),
-    sequence_length_(parameters.sequence_length),
-    kv_sequence_length_(parameters.kv_sequence_length),
-    past_sequence_length_(parameters.past_sequence_length),
-    total_sequence_length_(parameters.total_sequence_length),
-    max_sequence_length_(parameters.max_sequence_length),
-    input_hidden_size_(parameters.input_hidden_size),
-    hidden_size_(parameters.hidden_size),
-    head_size_(parameters.head_size),
-    
v_hidden_size_(parameters.v_hidden_size), - v_head_size_(parameters.v_head_size), - num_heads_(parameters.num_heads), - is_unidirectional_(parameters.is_unidirectional), - past_present_share_buffer_(parameters.past_present_share_buffer), - do_rotary_(parameters.do_rotary), - broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), - broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), - mask_filter_value_(parameters.mask_filter_value), - scale_(parameters.scale), - mask_type_(parameters.mask_type), - qkv_format_(parameters.qkv_format) { - } + WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.kv_sequence_length), + past_sequence_length_(parameters.past_sequence_length), + total_sequence_length_(parameters.total_sequence_length), + max_sequence_length_(parameters.max_sequence_length), + input_hidden_size_(parameters.input_hidden_size), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.v_hidden_size), + v_head_size_(parameters.v_head_size), + num_heads_(parameters.num_heads), + is_unidirectional_(parameters.is_unidirectional), + past_present_share_buffer_(parameters.past_present_share_buffer), + do_rotary_(parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.mask_filter_value), + scale_(parameters.scale), + mask_type_(parameters.mask_type), + qkv_format_(parameters.qkv_format) { + } - WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters): - is_gqa_parameters_(true), - batch_size_(parameters.batch_size), - sequence_length_(parameters.sequence_length), - total_sequence_length_(parameters.total_sequence_length), - hidden_size_(parameters.hidden_size), - head_size_(parameters.head_size), - num_heads_(parameters.num_heads), - do_rotary_(parameters.do_rotary), - seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), - seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), - kv_hidden_size_(parameters.kv_hidden_size), - kv_num_heads_(parameters.kv_num_heads), - num_splits_(parameters.num_splits), - rotary_dim_(parameters.rotary_dim), - rotary_interleaved_(parameters.rotary_interleaved), - use_smooth_softmax_(parameters.use_smooth_softmax), - softcap_(parameters.softcap), - zeros_count_(parameters.zeros_count), - zero_ptr_(parameters.zero_ptr), - n_reps(parameters.num_heads / parameters.kv_num_heads), - qkv_format_(parameters.qkv_format) { - } + WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + total_sequence_length_(parameters.total_sequence_length), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + num_heads_(parameters.num_heads), + do_rotary_(parameters.do_rotary), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), + kv_hidden_size_(parameters.kv_hidden_size), + kv_num_heads_(parameters.kv_num_heads), + num_splits_(parameters.num_splits), + rotary_dim_(parameters.rotary_dim), + rotary_interleaved_(parameters.rotary_interleaved), + use_smooth_softmax_(parameters.use_smooth_softmax), + softcap_(parameters.softcap), + 
zeros_count_(parameters.zeros_count), + zero_ptr_(parameters.zero_ptr), + n_reps(parameters.num_heads / parameters.kv_num_heads), + qkv_format_(parameters.qkv_format) { + } bool is_gqa_parameters_; - int batch_size_ = 0; + int batch_size_ = 0; int sequence_length_ = 0; int kv_sequence_length_ = 0; // input sequence length of K or V int past_sequence_length_ = 0; // sequence length in past state of K or V @@ -85,15 +83,16 @@ struct WebgpuAttentionParameters { bool broadcast_attn_bias_dim_1_ = false; float mask_filter_value_; float scale_; - bool use_tf32_ = false;; + bool use_tf32_ = false; + ; // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters // and not in onnxruntime::contrib::AttentionParameters int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor int kv_hidden_size_ = 0; int kv_num_heads_ = 0; - int num_splits_ = 0; // number of splits for splitkv - int rotary_dim_ = 0; // rotary embedding dimension + int num_splits_ = 0; // number of splits for splitkv + int rotary_dim_ = 0; // rotary embedding dimension int local_window_size_ = 0; bool kv_share_buffer_ = false; bool is_packed_qkv_ = false; @@ -102,14 +101,15 @@ struct WebgpuAttentionParameters { bool rotary_interleaved_ = false; bool use_smooth_softmax_ = false; float softcap_ = 0.0; - int zeros_count_ = 0;; + int zeros_count_ = 0; + ; int* zero_ptr_ = nullptr; // Computed values int n_reps = 1; AttentionMaskType mask_type_ = MASK_NONE; AttentionQkvFormat qkv_format_ = UNKNOWN; - }; +}; -} -} -} +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime From 8d104726dde3d3b69dc21a072e1279aa203952d4 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 29 Oct 2024 12:07:27 -0700 Subject: [PATCH 05/43] Modified MultiHeadAttention to not derive from AttentionBase class --- onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index fd1c37356df23..226cbb7bc0254 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -7,9 +7,6 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_kernel.h" -#include "contrib_ops/webgpu/bert/attention.h" - -#include "contrib_ops/cpu/bert/attention_base.h" namespace onnxruntime { namespace contrib { @@ -17,7 +14,7 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class MultiHeadAttention final : public WebGpuKernel, public AttentionBase { +class MultiHeadAttention final : public WebGpuKernel { public: MultiHeadAttention(const OpKernelInfo& info); Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; From 4ea58d1e568ffbf58911b4d92bea08e066ab21e4 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 29 Oct 2024 12:08:27 -0700 Subject: [PATCH 06/43] Uncomment GQA registration --- onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 4006006a76ba8..2e7ed5a16a2f0 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc 
@@ -42,7 +42,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,  //
       BuildKernelCreateInfo,  //
       BuildKernelCreateInfo,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       // BuildKernelCreateInfo,
From 4bcf257abe87907f4c280737f85d4c5eb3493808 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 29 Oct 2024 12:09:22 -0700
Subject: [PATCH 07/43] Moved TransferBSDToBNSH and ApplyAttention declaration
 to attention_common.h from attention.h

---
 .../contrib_ops/webgpu/bert/attention_common.h | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index 0ecb38cdc83fd..c5a4d59a0b24c 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -81,8 +81,8 @@ struct WebgpuAttentionParameters {
   bool do_rotary_ = false;
   bool broadcast_attn_bias_dim_0_ = false;
   bool broadcast_attn_bias_dim_1_ = false;
-  float mask_filter_value_;
-  float scale_;
+  float mask_filter_value_ = -10000.0f;
+  float scale_ = 0.0f;
   bool use_tf32_ = false;
   ;
   // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters
@@ -110,6 +110,14 @@ struct WebgpuAttentionParameters {
   AttentionQkvFormat qkv_format_ = UNKNOWN;
 };

+Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
+                         int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);
+
+Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
+                      const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
+                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr);
+
+
 } // namespace webgpu
 } // namespace contrib
 } // namespace onnxruntime
From 5c5c934429cfe97d38c1b6eeedd5ae52c40d5101 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 29 Oct 2024 16:39:02 -0700
Subject: [PATCH 08/43] Revert "Modified MultiHeadAttention to not derive from
 AttentionBase class"

This reverts commit ba453039352f13b9086ee3e79305d460676ffc73.
---
 .../webgpu/bert/multihead_attention.cc        |  8 +------
 .../webgpu/bert/multihead_attention.h         | 11 ++++------
 2 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index 4a284c7989b2b..50a462325e412 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -405,13 +405,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
 }

 MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
-    : WebGpuKernel(info) {
-  int64_t num_heads = 0;
-  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
-  num_heads_ = static_cast<int>(num_heads);
-  mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
-  scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
-  is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+    : WebGpuKernel(info), AttentionBase(info, false) {
   ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel");
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
index 226cbb7bc0254..d983236422c9e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
@@ -7,6 +7,9 @@
 #include "core/providers/webgpu/program.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_kernel.h"
+#include "contrib_ops/webgpu/bert/attention.h"
+
+#include "contrib_ops/cpu/bert/attention_base.h"

 namespace onnxruntime {
 namespace contrib {
@@ -14,16 +17,10 @@ namespace webgpu {

 using namespace onnxruntime::webgpu;

-class MultiHeadAttention final : public WebGpuKernel {
+class MultiHeadAttention final : public WebGpuKernel, public AttentionBase {
  public:
   MultiHeadAttention(const OpKernelInfo& info);
   Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
-
- protected:
-  int num_heads_;
-  float mask_filter_value_;
-  float scale_;
-  bool is_unidirectional_{false};
 };

 }  // namespace webgpu
From e7165469b6d09c366489bc846263589e6e866700 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 29 Oct 2024 17:19:58 -0700
Subject: [PATCH 09/43] Converted CheckInputs function to template to fix
 compiler/linker multiple definitions error.
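
Background, roughly: CheckInputs is defined in a header that several
translation units include, so each one emitted its own copy of the function
and the linker rejected the duplicates. Function templates (like inline
functions) have vague linkage, so identical instantiations from different
translation units are merged instead of clashing. A minimal self-contained
sketch of the effect — the Tensor struct and the body below are illustrative
stand-ins, not the real ONNX Runtime types:

    // odr_sketch.cc - hypothetical, compile-and-run illustration only.
    #include <iostream>

    struct Tensor {  // stand-in for onnxruntime::Tensor
      int num_dims = 0;
    };

    // A plain `bool CheckInputs(const Tensor*, int)` *defined* in a shared
    // header would be emitted by every including .cc file -> "multiple
    // definition" at link time. As a template the definition has vague
    // linkage and links cleanly; marking it `inline` would also work.
    template <typename T>
    bool CheckInputs(const T* query, int expected_dims) {
      return query != nullptr && query->num_dims == expected_dims;
    }

    int main() {
      Tensor t{3};
      std::cout << std::boolalpha << CheckInputs(&t, 3) << "\n";  // true
      return 0;
    }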
---
 .../cpu/bert/group_query_attention_helper.h   | 38 ++++++++++---------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 0bdee151d2173..19c3981f491fc 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -11,18 +11,19 @@
 namespace onnxruntime {
 namespace contrib {
 namespace group_query_attention_helper {

-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template<typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap) {
   // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
@@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
   return Status::OK();
 }

-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template<typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap,
                    int max_threads_per_block) {
From aba59e5aff948d1d4d9f53c25446fcf834959eff Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 29 Oct 2024 18:22:32 -0700
Subject: [PATCH 10/43] lint

---
 .../cpu/bert/group_query_attention_helper.h            | 4 ++--
 onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 19c3981f491fc..4cc5a4228dc8c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -11,7 +11,7 @@
 namespace onnxruntime {
 namespace contrib {
 namespace group_query_attention_helper {
-template<typename T>
+template <typename T>
 Status CheckInputs(const T* query,
@@ -266,7 +266,7 @@ Status CheckInputs(const T* query,
   return Status::OK();
 }

-template<typename T>
+template <typename T>
 Status CheckInputs(const T* query,
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index c5a4d59a0b24c..d5ade10420804 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -117,7 +117,6 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
                       WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr);

-
 } // namespace webgpu
 } // namespace contrib
} // namespace onnxruntime
From 067ecd18df231df480cd922f30761a819d41f891 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 29 Oct 2024 23:49:21 -0700
Subject: [PATCH 11/43] Fixed conflicts.

---
 .../contrib_ops/webgpu/bert/attention.cc      | 486 ++++++++++++++++++
 .../contrib_ops/webgpu/bert/attention.h       | 117 +++++
 .../webgpu/bert/group_query_attention.cc      |  98 ++++
 .../webgpu/bert/group_query_attention.h       |  53 ++
 .../webgpu/bert/multihead_attention.cc        | 384 +-------------
 5 files changed, 757 insertions(+), 381 deletions(-)
 create mode 100644 onnxruntime/contrib_ops/webgpu/bert/attention.cc
 create mode 100644 onnxruntime/contrib_ops/webgpu/bert/attention.h
 create mode 100644 onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
 create mode 100644 onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
new file mode 100644
index 0000000000000..5b69f68f968df
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -0,0 +1,486 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/webgpu/bert/attention.h"
+#include "contrib_ops/webgpu/bert/multihead_attention.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+using namespace onnxruntime::webgpu;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::contrib::multihead_attention_helper;
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("qkv_input", ShaderUsage::UseUniform);
+  const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices);
+
+  if (has_bias_) {
+    shader.AddInput("bias", ShaderUsage::UseUniform);
+  }
+
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
+                            << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n"
+                            << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *"
+                            << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n";
+  if (has_bias_) {
+    shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n";
+  }
+  shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]";
+  if (has_bias_) {
+    shader.MainFunctionBody() << " + bias[bias_offset_idx];\n";
+  } else {
+    shader.MainFunctionBody() << ";\n";
+  }
+
+  return Status::OK();
+}
+
+Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
+                         int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) {
+  ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3);
+  ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4);
+
+  uint32_t data_size = SafeInt<uint32_t>(output_tensor->Shape().Size());
+  const int batch_offset = num_heads * sequence_length * head_size;
+  const int sequence_offset = num_heads * head_size;
+  const int head_offset = head_size;
+  bool has_bias = bias != nullptr;
+
+  TransferBSDToBNSHProgram program{has_bias};
+  program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
+      
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
+      .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({{data_size},
+                            {static_cast<uint32_t>(batch_offset)},
+                            {static_cast<uint32_t>(sequence_offset)},
+                            {static_cast<uint32_t>(head_offset)},
+                            {static_cast<uint32_t>(bias_offset)}});
+
+  if (has_bias) {
+    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
+  }
+
+  return context.RunProgram(program);
+};
+
+void InitVarStub(const Tensor* seqlens_k, const Tensor* total_seqlen_tensor, bool init_past_sequence_length, std::ostringstream& ss) {
+  if (seqlens_k != nullptr && total_seqlen_tensor != nullptr) {
+    ss << "let total_sequence_length_input = u32(total_seqlen_tensor[0]);\n";
+    ss << "let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);\n";
+    ss << "let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;\n";
+    ss << "let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;\n";
+    ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n";
+    ss << "var past_sequence_length: u32 = 0;\n";
+    ss << "if (is_first_prompt == false) {\n";
+    ss << "  past_sequence_length = total_sequence_length - sequence_length;\n";
+    ss << "}\n";
+  } else {
+    if (init_past_sequence_length) {
+      ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
+    }
+    ss << "let present_sequence_length = total_sequence_length;\n";
+  }
+}
+
+Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  if (feed_past_key_) {
+    shader.AddInput("past_key", ShaderUsage::UseUniform);
+  }
+  if (has_attention_bias_) {
+    shader.AddInput("attention_bias", ShaderUsage::UseUniform);
+  }
+  if (seqlen_k_ != nullptr) {
+    shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
+  }
+  if (total_seqlen_tensor_ != nullptr) {
+    shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
+  }
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  if (has_present_key_) {
+    shader.AddOutput("present_key", ShaderUsage::UseUniform);
+  }
+
+  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, TILE_SIZE * TILE_SIZE>;\n"
+                                    << "var<workgroup> tileK: array<key_value_t, TILE_SIZE * TILE_SIZE>;\n"
+                                    << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ?
"vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + "let kv_head_idx = head_idx / uniforms.n_reps;\n" + "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" + "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + "let m = workgroup_id.y * TILE_SIZE;\n" + "let n = workgroup_id.x * TILE_SIZE;\n" + "let sequence_length = uniforms.M;\n" + "var total_sequence_length = uniforms.N;\n" + "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" + "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"; + std::ostringstream oss; + InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); + shader.MainFunctionBody() << oss.str(); + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; + } + shader.MainFunctionBody() << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n"; + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n"; + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" + " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + " }\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " }\n"; + } + + if (has_present_key_) { + shader.MainFunctionBody() << " if (n + local_id.y < present_sequence_length) {\n" + << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" + << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? 
"value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, + const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) + : parameters.scale_; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + const int tile_size = 12; + const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components, seqlen_k, total_seqlen_tensor}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) { + program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = parameters.head_size_ / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size)) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (seqlen_k_) { + shader.AddInput("seqlens_k", ShaderUsage::UseUniform); + } + if (total_seqlen_tensor_) { + shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform); + } + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let sequence_length = uniforms.sequence_length;\n" + << "var total_sequence_length = uniforms.total_sequence_length;\n"; + std::ostringstream oss; + InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); + shader.MainFunctionBody() << oss.str() + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "total_sequence_length") << ";\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + if (seqlen_k_) { + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {\n" + << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" + << "}\n"; + } + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, + const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1));
+  int work_group_size = 64;
+  const int total_sequence_length_comp = total_sequence_length / components;
+  if (total_sequence_length_comp < work_group_size) {
+    work_group_size = 32;
+  }
+  const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
+
+  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k, total_seqlen_tensor};
+  if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
+    program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
+                       {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+  }
+  program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
+      .SetDispatchGroupSize(batch_size * num_heads, sequence_length, total_sequence_length)
+      .SetWorkgroupSize(work_group_size)
+      .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
+                            {static_cast<uint32_t>(num_heads)},
+                            {static_cast<uint32_t>(past_sequence_length)},
+                            {static_cast<uint32_t>(sequence_length)},
+                            {static_cast<uint32_t>(total_sequence_length)},
+                            {static_cast<uint32_t>(elementsPerThread)}});
+
+  return context.RunProgram(program);
+}
+
+Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  if (feed_past_value_) {
+    shader.AddInput("past_value", ShaderUsage::UseUniform);
+  }
+  if (seqlen_k_) {
+    shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
+  }
+  if (total_seqlen_tensor_) {
+    shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
+  }
+  shader.AddOutput("output", ShaderUsage::UseUniform);
+  if (has_present_value_) {
+    shader.AddOutput("present_value", ShaderUsage::UseUniform);
+  }
+
+  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, TILE_SIZE * TILE_SIZE>;\n"
+                                    << "var<workgroup> tileK: array<v_value_t, TILE_SIZE * TILE_SIZE>;\n";
+  shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
+                            << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
+                            << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
+                            << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
+                            << "let m = global_id.y;\n"
+                            << "let n = global_id.x;\n"
+                            << "let sequence_length = uniforms.M;\n"
+                            << "var total_sequence_length = uniforms.K;\n";
+  std::ostringstream oss;
+  InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+  shader.MainFunctionBody() << oss.str()
+                            << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
+                            << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n";
+
+  if (feed_past_value_ && has_present_value_) {
+    shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n\n";
+  }
+
+  if (has_present_value_) {
+    shader.MainFunctionBody() << "let presentValueOffset = head_idx * uniforms.N * uniforms.K + n;\n";
+  }
+
+  shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
+                            << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
+                            << "  if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
+                            << "    tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
+                            << "  }\n"
+                            << "  if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
+                            << "    var idx = TILE_SIZE * local_id.y + local_id.x;\n";
+
+  if (feed_past_value_ && has_present_value_) {
+    shader.MainFunctionBody() << "      if (w + local_id.y < past_sequence_length) {\n"
+                              << "        tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * 
uniforms.N];\n" + << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n" + << " tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" + << " }\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << " if (w + local_id.y < present_sequence_length) {\n" + << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + WebgpuAttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length, + const Tensor* seqlen_k, + const Tensor* total_seqlen_tensor) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, seqlen_k, total_seqlen_tensor}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) { + program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size_)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(parameters.v_hidden_size_)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + ; + + return context.RunProgram(program); +} + +Status 
ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + + const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, total_seqlen_tensor)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h new file mode 100644 index 0000000000000..f3c9573bb31df --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias, int tile_size, int components, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; + const Tensor* seqlen_k_; + const Tensor* total_seqlen_tensor_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; + const Tensor* seqlen_k_; + const Tensor* total_seqlen_tensor_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const 
std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) {
+  }
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
+                                          {"K", ProgramUniformVariableDataType::Uint32},
+                                          {"N", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"head_size", ProgramUniformVariableDataType::Uint32},
+                                          {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
+                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"n_reps", ProgramUniformVariableDataType::Uint32});
+
+  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
+
+ private:
+  bool feed_past_value_;
+  bool has_present_value_;
+  int tile_size_;
+  const Tensor* seqlen_k_;
+  const Tensor* total_seqlen_tensor_;
+};
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
new file mode 100644
index 0000000000000..406c7fb92358d
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -0,0 +1,98 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
+#include "contrib_ops/webgpu/bert/attention_common.h"
+#include "contrib_ops/webgpu/bert/group_query_attention.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+
+#include "core/providers/webgpu/webgpu_supported_types.h"
+
+using namespace onnxruntime::webgpu;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::contrib::group_query_attention_helper;
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+ONNX_OPERATOR_KERNEL_EX(
+    GroupQueryAttention,
+    kMSDomain,
+    1,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
+    GroupQueryAttention);
+
+Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
+  const Tensor* query = context.Input<Tensor>(0);
+  const Tensor* key = context.Input<Tensor>(1);
+  const Tensor* value = context.Input<Tensor>(2);
+  const Tensor* past_key = context.Input<Tensor>(3);
+  const Tensor* past_value = context.Input<Tensor>(4);
+  const Tensor* seqlen_k = context.Input<Tensor>(5);
+  const Tensor* total_seqlen_tensor = context.Input<Tensor>(6);
+  const Tensor* cos_cache = context.Input<Tensor>(7);
+  const Tensor* sin_cache = context.Input<Tensor>(8);
+
+  GroupQueryAttentionParameters params;
+  ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
+                                                                key,
+                                                                value,
+                                                                past_key,
+                                                                past_value,
+                                                                cos_cache,
+                                                                sin_cache,
+                                                                &params,
+                                                                num_heads_,
+                                                                kv_num_heads_,
+                                                                seqlen_k,
+                                                                total_seqlen_tensor,
+                                                                scale_,
+                                                                softcap_));
+  WebgpuAttentionParameters parameters(params);
+  if (parameters.is_packed_qkv_) {
+    ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep.");
+  }
+  TensorShapeVector output_shape(3);
+  output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
+  output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
+  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size_);
+  Tensor* output = context.Output(0, output_shape);
+  const int present_kv_seqlen = parameters.seqlen_present_kv_cache_;
+  std::vector<int64_t> present_kv_shape({static_cast<int64_t>(parameters.batch_size_), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(parameters.head_size_)});
+  Tensor* present_key = context.Output(1, present_kv_shape);
+  Tensor* present_value = context.Output(2, present_kv_shape);
+
+  TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.sequence_length_, parameters.head_size_});
+  TensorShape q_new_shape(q_new_dims);
+  Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
+      context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
+  TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.kv_sequence_length_, parameters.head_size_});
+  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
+    return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key,
+                          present_value, parameters, context, seqlen_k, total_seqlen_tensor);
+  }
+  TensorShape k_new_shape(k_new_dims);
+  Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
+                                        parameters.head_size_, key, nullptr, parameters.hidden_size_, &K));
+
+  TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_,
+                                parameters.kv_sequence_length_, parameters.v_head_size_});
+  TensorShape v_new_shape(v_new_dims);
+  Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
+                                        parameters.v_head_size_, value, nullptr, 2 * parameters.hidden_size_, &V));
+  return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key,
+                        present_value, parameters, context, seqlen_k, total_seqlen_tensor);
+}
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
new file mode 100644
index 0000000000000..04969dc778927
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class GroupQueryAttention final : public WebGpuKernel { + public: + GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + float softcap_; + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int local_window_size_; + + bool use_smooth_softmax_; + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 50a462325e412..424556c66bd9d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -26,384 +26,6 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedFloatTypes()), MultiHeadAttention); -Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("qkv_input", ShaderUsage::UseUniform); - const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); - - if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); - } - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" - << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" - << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; - if (has_bias_) { - shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; - } - shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; - if (has_bias_) { - shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; - } else { - shader.MainFunctionBody() << ";\n"; - } - - return Status::OK(); -} - -Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, - int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { - assert(input_tensor->Shape().GetDims().size() == 3); - assert(output_tensor->Shape().GetDims().size() == 4); - - uint32_t 
data_size = gsl::narrow(output_tensor->Shape().Size()); - const int batch_offset = num_heads * sequence_length * head_size; - const int sequence_offset = num_heads * head_size; - const int head_offset = head_size; - bool has_bias = bias != nullptr; - - TransferBSDToBNSHProgram program{has_bias}; - program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({{data_size}, - {static_cast(batch_offset)}, - {static_cast(sequence_offset)}, - {static_cast(head_offset)}, - {static_cast(bias_offset)}}); - - if (has_bias) { - program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); - } - - return context.RunProgram(program); -}; - -Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (feed_past_key_) { - shader.AddInput("past_key", ShaderUsage::UseUniform); - } - if (has_attention_bias_) { - shader.AddInput("attention_bias", ShaderUsage::UseUniform); - } - - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (has_present_key_) { - shader.AddOutput("present_key", ShaderUsage::UseUniform); - } - - shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array;\n" - << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - "let headIdx = workgroup_id.z;\n" - "let m = workgroup_id.y * TILE_SIZE;\n" - "let n = workgroup_id.x * TILE_SIZE;\n" - "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; - - if (feed_past_key_ && has_present_key_) { - shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" - << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; - } else { - shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; - } - - if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; - } - - shader.MainFunctionBody() << "var value = f32_val_t(0);\n" - "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" - " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" - " }\n" - " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" - " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - - if (feed_past_key_ && has_present_key_) { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n" - " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - " } else {\n" - " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" - " }\n"; - } else { - shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; - } - - if (has_present_key_) { - shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; - } - - shader.MainFunctionBody() << " }\n" - << " workgroupBarrier();\n" - << " 
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" - << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" - << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" - << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" - << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; - - shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; - if (has_attention_bias_) { - shader.MainFunctionBody() << " + attention_bias[outputIdx]"; - } - shader.MainFunctionBody() << ";\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, - const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, - AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { - const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - - const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; - const bool has_present_key = output_count > 1 && past_key; - const bool has_attention_bias = attention_bias != nullptr; - constexpr int tile_size = 12; - const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); - - AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components}; - program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (feed_past_key) { - program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); - } - if (has_attention_bias) { - program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); - } - program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); - if (has_present_key) { - program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); - } - - const uint32_t vectorized_head_size = parameters.head_size / components; - program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, - (parameters.sequence_length + tile_size - 1) / tile_size, - parameters.batch_size * parameters.num_heads) - .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size)) - .AddUniformVariables({{static_cast(parameters.sequence_length)}, - {static_cast(vectorized_head_size)}, - {static_cast(total_sequence_length)}, - {static_cast(parameters.num_heads)}, - {static_cast(alpha)}, - {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}) - .SetOverridableConstants({{static_cast(tile_size)}}); - - return context.RunProgram(program); -} - -Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AdditionalImplementation() << "var thread_max: array;\n" - << "var thread_sum: array;\n" - << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" - << "}\n" - << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var max_value = f32(-3.402823e+38f);\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" - << " max_value = max(thread_max[i], max_value);\n" - << "}\n" - << "var sum_vector = f32_val_t(0);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" - << "}\n" - << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var sum: f32 = 0;\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" - << " sum += thread_sum[i]\n;" - << "}\n" - << "if (sum == 0) {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" - << " }\n" - << "} else {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " var f32input = f32_val_t(x[offset + i]);\n" - << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" - << " }\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { - const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 
2 : 1); - int work_group_size = 64; - const int d_comp = d / components; - if (d_comp < work_group_size) { - work_group_size = 32; - } - const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; - - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; - program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .SetDispatchGroupSize(n) - .SetWorkgroupSize(work_group_size) - .AddUniformVariables({{static_cast(1.f / static_cast(d))}, - {static_cast(d_comp)}, - {static_cast(elementsPerThread)}}); - - return context.RunProgram(program); -} - -Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (feed_past_value_) { - shader.AddInput("past_value", ShaderUsage::UseUniform); - } - - shader.AddOutput("output", ShaderUsage::UseUniform); - if (has_present_value_) { - shader.AddOutput("present_value", ShaderUsage::UseUniform); - } - - shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array;\n"; - - shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; - - if (feed_past_value_ && has_present_value_) { - shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" - << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; - } else { - shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; - } - - if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; - } - - shader.MainFunctionBody() << "var value = probs_element_t(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" - << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" - << " }\n" - << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" - << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - - if (feed_past_value_ && has_present_value_) { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n" - << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" - << " } else {\n" - << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" - << " }\n"; - } else { - shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; - } - - if (has_present_value_) { - shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; - } - - shader.MainFunctionBody() << " }\n" - << " workgroupBarrier();\n" - << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" - << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" - << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" - << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" - << "if (m < uniforms.M && n < uniforms.N) {\n" - << " let outputIdx = batchIdx * 
uniforms.M * uniforms.v_hidden_size + " - << " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" - << " output[outputIdx] = value;\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, - const Tensor* probs, - const Tensor* V, - const Tensor* past_value, - Tensor* output, - Tensor* present_value, - AttentionParameters& parameters, - int past_sequence_length, - int total_sequence_length) { - const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; - const bool has_present_value = output_count > 1 && past_value != nullptr; - constexpr int tile_size = 12; - - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; - program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, - {V, ProgramTensorMetadataDependency::TypeAndRank}}); - if (feed_past_value) { - program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); - } - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); - if (has_present_value) { - program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); - } - - program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, - (parameters.sequence_length + tile_size - 1) / tile_size, - parameters.batch_size * parameters.num_heads) - .SetWorkgroupSize(tile_size, tile_size) - .AddUniformVariables({{static_cast(parameters.sequence_length)}, - {static_cast(total_sequence_length)}, - {static_cast(parameters.v_head_size)}, - {static_cast(parameters.num_heads)}, - {static_cast(parameters.v_hidden_size)}, - {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}) - .SetOverridableConstants({{static_cast(tile_size)}}); - ; - - return context.RunProgram(program); -} - -Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, - const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); - const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length : 0;
-  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;
-
-  const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads,
-                                      parameters.sequence_length, total_sequence_length});
-  const TensorShape probs_shape(probs_dims);
-  Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
-  ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
-                                            parameters, past_sequence_length, total_sequence_length));
-
-  ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
-                                            parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length));
-
-  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
-                                              parameters, past_sequence_length, total_sequence_length));
-
-  return Status::OK();
-}
-
 MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
     : WebGpuKernel(info), AttentionBase(info, false) {
   ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel");
@@ -429,12 +51,12 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
     ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu");
   }

-  AttentionParameters parameters;
+  AttentionParameters params;
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value,
-                                                              bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &parameters,
+                                                              bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &params,
                                                               num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention,
                                                               context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
-
+  WebgpuAttentionParameters parameters(params);
   TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
   output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);

From 53f1c78de74a438fbe93095b56ae7fbc7cddef87 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Wed, 30 Oct 2024 09:02:14 -0700
Subject: [PATCH 12/43] copying errors

---
 onnxruntime/contrib_ops/webgpu/bert/attention.cc | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 5b69f68f968df..1e1120d6b54c6 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -347,6 +347,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
                             << "let m = global_id.y;\n"
                             << "let n = global_id.x;\n"
+                            << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.K;\n";
   std::ostringstream oss;
@@ -356,11 +357,11 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n";

   if (feed_past_value_ && has_present_value_) {
-    shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n\n";
+    shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n";
   }

   if (has_present_value_) {
-
shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n"; } shader.MainFunctionBody() << "var value = probs_element_t(0);\n" @@ -379,13 +380,13 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n"; } else { shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n" - << " tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" + << " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" << " }\n"; } if (has_present_value_) { shader.MainFunctionBody() << " if (w + local_id.y < present_sequence_length) {\n" - << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];\n" + << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; } @@ -398,8 +399,6 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "}\n"; shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" << "if (m < uniforms.M && n < uniforms.N) {\n" << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" From f4dc9fc67ebaaa94b43e3b8d02f0fa72aa4c2911 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 31 Oct 2024 00:00:56 -0700 Subject: [PATCH 13/43] Fixed inplacesoftmax dispatch --- .../contrib_ops/webgpu/bert/attention.cc | 81 ++++++++++++------- .../contrib_ops/webgpu/bert/attention.h | 10 ++- 2 files changed, 56 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 1e1120d6b54c6..0bc834e1af93d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -111,26 +111,35 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - "let kv_head_idx = head_idx / uniforms.n_reps;\n" - "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" - "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - "let m = workgroup_id.y * TILE_SIZE;\n" - "let n = workgroup_id.x * TILE_SIZE;\n" - "let sequence_length = uniforms.M;\n" - "var total_sequence_length = uniforms.N;\n" - "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" - "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"; + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); shader.MainFunctionBody() << oss.str(); - - if (feed_past_key_ && has_present_key_) { - shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; - } - shader.MainFunctionBody() << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n"; - if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n"; + if (n_reps_ > 1) { + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let kv_head_idx = head_idx / uniforms.n_reps;\n" + << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" + << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" + << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";; + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; + } + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n"; + } + } else { + shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; + } + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n"; + } } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" @@ -196,7 +205,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, seqlen_k, total_seqlen_tensor}; + components, parameters.n_reps, seqlen_k, total_seqlen_tensor}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -310,13 +319,13 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .SetDispatchGroupSize(batch_size * num_heads, sequence_length, total_sequence_length) + .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, {static_cast(past_sequence_length)}, {static_cast(sequence_length)}, - {static_cast(total_sequence_length)}, + {static_cast(total_sequence_length_comp)}, {static_cast(elementsPerThread)}}); return context.RunProgram(program); @@ -343,8 +352,6 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var tileK: array;\n"; shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let kv_head_idx = head_idx / uniforms.n_reps;\n" - << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" << "let m = global_id.y;\n" << "let n = global_id.x;\n" << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" @@ -352,16 +359,28 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); - shader.MainFunctionBody() << oss.str() - << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" - << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + shader.MainFunctionBody() << oss.str(); + if (n_reps_ > 1) { + shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n" + << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" + << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" + << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n"; + } - if (feed_past_value_ && has_present_value_) { - shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n"; - } + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n"; + } + } else { + shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n"; + } - if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n"; + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.K + n;\n"; + } } 
shader.MainFunctionBody() << "var value = probs_element_t(0);\n" @@ -423,7 +442,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, seqlen_k, total_seqlen_tensor}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, total_seqlen_tensor}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index f3c9573bb31df..17f7b9e9c5fcd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -58,6 +58,7 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; + int n_reps_; const Tensor* seqlen_k_; const Tensor* total_seqlen_tensor_; }; @@ -86,8 +87,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -108,6 +109,7 @@ class VxAttentionScoreProgram final : public Program { bool feed_past_value_; bool has_present_value_; int tile_size_; + int n_reps_; const Tensor* seqlen_k_; const Tensor* total_seqlen_tensor_; }; From 3d1af1c61358557a2a0a451de84816bbe8bfc764 Mon Sep 17 00:00:00 2001 
From: Satya Jandhyala Date: Fri, 1 Nov 2024 11:26:48 -0700 Subject: [PATCH 14/43] Initialize required parameter data --- onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index d5ade10420804..230207ed26c1a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -42,9 +42,12 @@ struct WebgpuAttentionParameters { WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true), batch_size_(parameters.batch_size), sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.sequence_length), total_sequence_length_(parameters.total_sequence_length), hidden_size_(parameters.hidden_size), head_size_(parameters.head_size), + v_hidden_size_(parameters.kv_hidden_size), + v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), num_heads_(parameters.num_heads), do_rotary_(parameters.do_rotary), seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), From 2eaeebc35af88623bd2f580a25cd22d61caeff8f Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 1 Nov 2024 11:27:54 -0700 Subject: [PATCH 15/43] Map total_seqlen_tensor input to CPU --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 406c7fb92358d..5c1f032d976aa 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -23,7 +23,8 @@ ONNX_OPERATOR_KERNEL_EX( 1, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedFloatTypes()), + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .InputMemoryType(OrtMemTypeCPUInput, 6), GroupQueryAttention); Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { From 9c828ccbf7bd8d139ceb102c9b401768d58a855f Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 31 Oct 2024 08:21:36 -0700 Subject: [PATCH 16/43] Use uniforms variable name consistently to avoid confusion. 
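The InPlaceSoftmax shader addresses each probs row in vector units (f32, vec2<f32> or vec4<f32>), so the value it needs in a uniform is the component-wise row length, not the element count; renaming it to total_sequence_length_comp makes that explicit at every use site. A minimal host-side sketch of the relationship (the helper name is illustrative, not part of the kernel):

    #include <cstdint>

    // The vectorized shader sees the row as total_sequence_length_comp values of
    // `components` elements each; the element count is recovered in the shader as
    // total_sequence_length_comp * components. components divides the row length
    // evenly by construction, and a non-null seqlens_k input forces components == 1
    // so causal masking can still zero individual elements.
    inline uint32_t TotalSequenceLengthComp(uint32_t total_sequence_length, uint32_t components) {
      return total_sequence_length / components;
    }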
--- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 8 ++++---- onnxruntime/contrib_ops/webgpu/bert/attention.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 0bc834e1af93d..e3b5bef6ef261 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -257,13 +257,13 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" << "let sequence_length = uniforms.sequence_length;\n" - << "var total_sequence_length = uniforms.total_sequence_length;\n"; + << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "total_sequence_length") << ";\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" @@ -295,7 +295,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << "}\n"; if (seqlen_k_) { - shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {\n" + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < total_sequence_length; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" << "}\n"; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 17f7b9e9c5fcd..ea0ad7e03fc54 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -75,7 +75,7 @@ class InPlaceSoftmaxProgram final : public Program { {"num_heads", ProgramUniformVariableDataType::Uint32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}, - {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); private: From 26caa060bb07507f0c268b39f820b16e06210a40 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 31 Oct 2024 08:29:05 -0700 Subject: [PATCH 17/43] Keep InplaceSoftmax dispatch 3-dim. 
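A 1-D dispatch would leave workgroup_id.y pinned at zero, but the shader derives seq_causal_length from workgroup_id.y when seqlens_k is present, so the dispatch stays 3-dimensional: (1, sequence_length, batch_size * num_heads) gives one workgroup per probs row. A sketch of the index recovery the WGSL performs, written as a standalone function (names are assumptions for illustration):

    #include <cstdint>

    struct ProbsRow {
      uint32_t batch_idx;  // workgroup_id.z / num_heads
      uint32_t head_idx;   // workgroup_id.z % num_heads
      uint32_t row;        // workgroup_id.y, the query position within the sequence
    };

    // Mirrors the shader's decoding of the (1, sequence_length, batch * num_heads) grid.
    inline ProbsRow DecodeWorkgroup(uint32_t workgroup_y, uint32_t workgroup_z, uint32_t num_heads) {
      return {workgroup_z / num_heads, workgroup_z % num_heads, workgroup_y};
    }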
---
 onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index e3b5bef6ef261..bacd3dbfd533d 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -201,7 +201,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
   const bool has_present_key = output_count > 1 && past_key;
   const bool has_attention_bias = attention_bias != nullptr;
-  const int tile_size = 12;
+  constexpr int tile_size = 12;
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
@@ -319,7 +319,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
                        {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .SetDispatchGroupSize(batch_size * num_heads * sequence_length)
+      .SetDispatchGroupSize(1, sequence_length,  batch_size * num_heads)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
                             {static_cast<uint32_t>(num_heads)},

From 64b093f6f1ce21d5b55e6e440bbd005e16aa7323 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Thu, 31 Oct 2024 08:53:30 -0700
Subject: [PATCH 18/43] Formatting changes.

---
 onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index bacd3dbfd533d..ea5ab3d091461 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -125,7 +125,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
                               << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
                               << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
-                              << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";;
+                              << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";
     if (feed_past_key_ && has_present_key_) {
       shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n";
     }
@@ -319,7 +319,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
                        {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
   }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .SetDispatchGroupSize(1, sequence_length,  batch_size * num_heads)
+      .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
                             {static_cast<uint32_t>(num_heads)},

From a8bd38bf55fb3de9a0920a6fcbf9b3747cc331e3 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Mon, 4 Nov 2024 11:25:17 -0800
Subject: [PATCH 19/43] Use total_seqlen_tensor input only to determine is_first_prompt.
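total_seqlen_tensor is mapped to CPU memory (patch 15), so the first-prompt test can be evaluated once on the host and handed to the shader as a uniform, instead of re-deriving present_sequence_length and is_subsequent_prompt in every invocation. A hedged sketch of the host-side predicate, matching the WGSL stub this patch removes (the helper is illustrative, not the actual code):

    #include <cstdint>

    // Per the removed stub, a request is the first prompt exactly when the query
    // spans the whole total sequence, i.e. no KV cache has been consumed yet;
    // subsequent prompts have sequence_length > 1 but a larger total length.
    inline bool IsFirstPrompt(int64_t sequence_length, int64_t total_sequence_length) {
      return sequence_length == total_sequence_length;
    }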
--- .../contrib_ops/webgpu/bert/attention.cc | 86 ++++++++----------- .../contrib_ops/webgpu/bert/attention.h | 26 +++--- .../webgpu/bert/attention_common.h | 5 +- .../webgpu/bert/group_query_attention.cc | 6 +- 4 files changed, 57 insertions(+), 66 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index ea5ab3d091461..a8459111b31b9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -68,22 +68,15 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; -void InitVarStub(const Tensor* seqlens_k, const Tensor* total_seqlen_tensor, bool init_past_sequence_length, std::ostringstream& ss) { - if (seqlens_k != nullptr && total_seqlen_tensor != nullptr) { - ss << "let total_sequence_length_input = u32(total_seqlen_tensor[0]);\n"; - ss << "let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);\n"; - ss << "let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;\n"; - ss << "let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;\n"; +void InitVarStub(std::ostringstream& ss, const Tensor* seqlens_k) { + if (seqlens_k != nullptr) { ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n"; ss << "var past_sequence_length: u32 = 0;\n"; - ss << "if (is_first_prompt == false) {\n"; + ss << "if (uniforms.is_first_prompt != 0) {\n"; ss << " past_sequence_length = total_sequence_length - sequence_length;\n"; ss << "}\n"; } else { - if (init_past_sequence_length) { - ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; - } - ss << "let present_sequence_length = total_sequence_length;\n"; + ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; } } @@ -99,9 +92,6 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (seqlen_k_ != nullptr) { shader.AddInput("seqlens_k", ShaderUsage::UseUniform); } - if (total_seqlen_tensor_ != nullptr) { - shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform); - } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); if (has_present_key_) { shader.AddOutput("present_key", ShaderUsage::UseUniform); @@ -118,7 +108,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; - InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); + InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); if (n_reps_ > 1) { shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" @@ -163,7 +153,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_key_) { - shader.MainFunctionBody() << " if (n + local_id.y < present_sequence_length) {\n" + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n" << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" << " }\n"; } @@ -194,7 +184,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, 
Tensor* probs, Tensor* present_key, WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, - const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + const Tensor* seqlen_k) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -205,7 +195,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.n_reps, seqlen_k, total_seqlen_tensor}; + components, parameters.n_reps, seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -214,9 +204,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } - if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) { - program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, - {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); if (has_present_key) { @@ -228,7 +217,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size)) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -237,7 +226,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(alpha)}, {static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length_)}, - {static_cast(parameters.n_reps)}}) + {static_cast(seqlen_k == nullptr ? 
total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}, + {static_cast(parameters.is_first_prompt_)}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); @@ -247,9 +238,6 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { if (seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::UseUniform); } - if (total_seqlen_tensor_) { - shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform); - } shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" @@ -259,7 +247,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.sequence_length;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; - InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); + InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" @@ -304,7 +292,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, - const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + const Tensor* seqlen_k, bool is_first_prompt) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1)); int work_group_size = 64; const int total_sequence_length_comp = total_sequence_length / components; @@ -313,12 +301,12 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k, total_seqlen_tensor}; - if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) { - program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, - {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k}; + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .CacheHint(std::to_string(work_group_size), is_first_prompt) .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -326,7 +314,8 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {static_cast(past_sequence_length)}, {static_cast(sequence_length)}, {static_cast(total_sequence_length_comp)}, - {static_cast(elementsPerThread)}}); + {static_cast(elementsPerThread)}, + {static_cast(is_first_prompt)}}); return context.RunProgram(program); } @@ -340,9 +329,6 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if (seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::UseUniform); } - if (total_seqlen_tensor_) { - shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform); - } shader.AddOutput("output", ShaderUsage::UseUniform); if (has_present_value_) { shader.AddOutput("present_value", ShaderUsage::UseUniform); @@ -358,7 +344,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; - InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss); + InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); if (n_reps_ > 1) { shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n" @@ -404,7 +390,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_value_) { - shader.MainFunctionBody() << " if (w + local_id.y < present_sequence_length) {\n" + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n" << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; } @@ -436,21 +422,19 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, - const Tensor* seqlen_k, - const Tensor* total_seqlen_tensor) { + const Tensor* seqlen_k) { const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, total_seqlen_tensor}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k}; 
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); } - if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) { - program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, - {total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); if (has_present_value) { @@ -460,6 +444,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, @@ -469,16 +454,17 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.v_hidden_size_)}, {static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length_)}, - {static_cast(parameters.n_reps)}}) + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}, + {static_cast(parameters.is_first_prompt_)}}) .SetOverridableConstants({{static_cast(tile_size)}}); - ; return context.RunProgram(program); } Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) { + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length_ : 0; const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; @@ -488,13 +474,13 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T const TensorShape probs_shape(probs_dims); Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, - parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor)); + parameters, past_sequence_length, total_sequence_length, seqlen_k)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, total_seqlen_tensor)); + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, - parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor)); + parameters, past_sequence_length, total_sequence_length, seqlen_k)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index ea0ad7e03fc54..8c6e27b9f9227 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -48,7 +48,9 @@ class AttentionProbsProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}); + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}, + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -60,13 +62,12 @@ class AttentionProbsProgram final : public Program { int components_; int n_reps_; const Tensor* seqlen_k_; - const Tensor* total_seqlen_tensor_; }; class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* 
seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -76,19 +77,19 @@ class InPlaceSoftmaxProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, - {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}, + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); private: int work_group_size_; int components_; const Tensor* seqlen_k_; - const Tensor* total_seqlen_tensor_; }; class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), total_seqlen_tensor_(total_seqlen_tensor) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -101,7 +102,9 @@ class VxAttentionScoreProgram final : public Program { {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}); + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}, + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -111,7 +114,6 @@ class VxAttentionScoreProgram final : public Program { int tile_size_; int n_reps_; const Tensor* seqlen_k_; - const Tensor* total_seqlen_tensor_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 230207ed26c1a..c010ffd49b86b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -43,6 +43,7 @@ struct WebgpuAttentionParameters { batch_size_(parameters.batch_size), sequence_length_(parameters.sequence_length), kv_sequence_length_(parameters.sequence_length), + past_sequence_length_(parameters.seqlen_past_kv_cache), total_sequence_length_(parameters.total_sequence_length), hidden_size_(parameters.hidden_size), head_size_(parameters.head_size), @@ -56,6 +57,8 @@ struct WebgpuAttentionParameters { kv_num_heads_(parameters.kv_num_heads), 
num_splits_(parameters.num_splits), rotary_dim_(parameters.rotary_dim), + is_subsequent_prompt_(parameters.is_subsequent_prompt), + is_first_prompt_(parameters.is_first_prompt), rotary_interleaved_(parameters.rotary_interleaved), use_smooth_softmax_(parameters.use_smooth_softmax), softcap_(parameters.softcap), @@ -118,7 +121,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* total_seqlen_tensor = nullptr); + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 5c1f032d976aa..1dbeeeda20164 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -60,7 +60,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); output_shape[1] = static_cast(parameters.sequence_length_); - output_shape[2] = static_cast(parameters.v_hidden_size_); + output_shape[2] = static_cast(parameters.hidden_size_); Tensor* output = context.Output(0, output_shape); const int present_kv_seqlen = parameters.seqlen_present_kv_cache_; std::vector present_kv_shape({static_cast(parameters.batch_size_), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(parameters.head_size_)}); @@ -77,7 +77,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& parameters.kv_sequence_length_, parameters.head_size_}); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k, total_seqlen_tensor); + present_value, parameters, context, seqlen_k); } TensorShape k_new_shape(k_new_dims); Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); @@ -91,7 +91,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 2 * parameters.hidden_size_, &V)); return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k, total_seqlen_tensor); + present_value, parameters, context, seqlen_k); } } // namespace webgpu From d613df4288fe4a6c35ddb54ea637a04076c6352b Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 4 Nov 2024 12:19:44 -0800 Subject: [PATCH 20/43] initialize is_packed_qkv_ --- onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index c010ffd49b86b..286991cf26135 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h 
@@ -57,6 +57,7 @@ struct WebgpuAttentionParameters { kv_num_heads_(parameters.kv_num_heads), num_splits_(parameters.num_splits), rotary_dim_(parameters.rotary_dim), + is_packed_qkv_(parameters.is_packed_qkv), is_subsequent_prompt_(parameters.is_subsequent_prompt), is_first_prompt_(parameters.is_first_prompt), rotary_interleaved_(parameters.rotary_interleaved), From 0fedb9fad66bc80ccd30edbcf3c9d6baf23be2c3 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 6 Nov 2024 10:14:16 -0800 Subject: [PATCH 21/43] Handle past key/value and present key/value buffer sharing. --- .../contrib_ops/webgpu/bert/attention.cc | 52 +++++++++++-------- .../contrib_ops/webgpu/bert/attention.h | 10 ++-- .../webgpu/bert/group_query_attention.cc | 5 +- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index a8459111b31b9..7d7680b7ed897 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -68,9 +68,9 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; -void InitVarStub(std::ostringstream& ss, const Tensor* seqlens_k) { - if (seqlens_k != nullptr) { - ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n"; +void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { + if (seqlen_k != nullptr) { + ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; ss << "var past_sequence_length: u32 = 0;\n"; ss << "if (uniforms.is_first_prompt != 0) {\n"; ss << " past_sequence_length = total_sequence_length - sequence_length;\n"; @@ -90,7 +90,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } if (seqlen_k_ != nullptr) { - shader.AddInput("seqlens_k", ShaderUsage::UseUniform); + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); if (has_present_key_) { @@ -116,7 +116,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n"; - if (feed_past_key_ && has_present_key_) { + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; } if (has_present_key_) { @@ -124,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; - if (feed_past_key_ && has_present_key_) { + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; } if (has_present_key_) { @@ -140,9 +140,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - if (feed_past_key_ && has_present_key_) { + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << " if (n + 
local_id.y < past_sequence_length) {\n" - " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" " }\n"; @@ -153,8 +153,12 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_key_) { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n" - << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" << " }\n"; } @@ -188,14 +192,14 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; - const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; const bool has_present_key = output_count > 1 && past_key; const bool has_attention_bias = attention_bias != nullptr; constexpr int tile_size = 12; const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.n_reps, seqlen_k}; + components, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -236,7 +240,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { if (seqlen_k_) { - shader.AddInput("seqlens_k", ShaderUsage::UseUniform); + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AdditionalImplementation() << "var thread_max: array;\n" @@ -327,7 +331,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("past_value", ShaderUsage::UseUniform); } if (seqlen_k_) { - shader.AddInput("seqlens_k", ShaderUsage::UseUniform); + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform); if (has_present_value_) { @@ -351,7 +355,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n"; - if (feed_past_value_ && has_present_value_) { + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n"; } @@ -360,7 +364,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n"; - if (feed_past_value_ && has_present_value_) { + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n"; } @@ -377,9 +381,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - if (feed_past_value_ && has_present_value_) { + if ((feed_past_value_ && has_present_value_) && past_present_share_buffer_) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" - << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" << " }\n"; @@ -390,8 +394,12 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_value_) { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n" - << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; } @@ -423,11 +431,11 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int int past_sequence_length, int total_sequence_length, const Tensor* seqlen_k) { - const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 8c6e27b9f9227..ee3aa1957cdd9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k) { + bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -62,6 +62,7 @@ class AttentionProbsProgram final : public Program { int components_; int n_reps_; const Tensor* seqlen_k_; + bool past_present_share_buffer_; }; class InPlaceSoftmaxProgram final : public Program { @@ -88,8 +89,8 @@ 
class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -114,6 +115,7 @@ class VxAttentionScoreProgram final : public Program { int tile_size_; int n_reps_; const Tensor* seqlen_k_; + bool past_present_share_buffer_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1dbeeeda20164..f6cdadac40481 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -24,7 +24,9 @@ ONNX_OPERATOR_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedFloatTypes()) - .InputMemoryType(OrtMemTypeCPUInput, 6), + .MayInplace(3, 1) + .MayInplace(4, 2) + .InputMemoryType(OrtMemTypeCPUInput, 6), GroupQueryAttention); Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { @@ -66,6 +68,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& std::vector present_kv_shape({static_cast(parameters.batch_size_), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(parameters.head_size_)}); Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); + parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); From 993140b275b6307f1c0fa25ac0c6ac67bbba73d1 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 6 Nov 2024 10:58:23 -0800 Subject: [PATCH 22/43] lint --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 6 +++--- .../contrib_ops/webgpu/bert/group_query_attention.cc | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 7d7680b7ed897..980bd7dcbe2c3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -124,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; - if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << "let 
pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; } if (has_present_key_) { @@ -142,7 +142,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" - " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" " }\n"; @@ -431,7 +431,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int int past_sequence_length, int total_sequence_length, const Tensor* seqlen_k) { - const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f6cdadac40481..70af3ea5a84bc 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -24,9 +24,9 @@ ONNX_OPERATOR_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedFloatTypes()) - .MayInplace(3, 1) - .MayInplace(4, 2) - .InputMemoryType(OrtMemTypeCPUInput, 6), + .MayInplace(3, 1) + .MayInplace(4, 2) + .InputMemoryType(OrtMemTypeCPUInput, 6), GroupQueryAttention); Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { From 7502493afd9365c97dc9e61bc48e432df38dee35 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 6 Nov 2024 21:05:19 -0800 Subject: [PATCH 23/43] Added past_present_share_buffer to the hint. 
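Note: the generated WGSL now branches on past_present_share_buffer_, so the flag must participate in the program cache key; otherwise a shader compiled for the split past/present layout could be reused for the shared-buffer layout. An illustrative sketch of the key composition (a hypothetical helper, not the actual CacheHint API):

  #include <string>

  // Every value that changes the emitted shader source must be encoded in
  // the hint so distinct variants get distinct cache entries.
  std::string MakeCacheHint(int tile_size, bool is_first_prompt,
                            bool past_present_share_buffer) {
    return std::to_string(tile_size) +
           (is_first_prompt ? "|first" : "|later") +
           (past_present_share_buffer ? "|shared_kv" : "|split_kv");
  }
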
typo --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 6 +++--- .../contrib_ops/webgpu/bert/group_query_attention.cc | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 980bd7dcbe2c3..2d552150c8284 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -221,7 +221,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -381,7 +381,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - if ((feed_past_value_ && has_present_value_) && past_present_share_buffer_) { + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" @@ -452,7 +452,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 70af3ea5a84bc..082666c11161e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -68,7 +68,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& std::vector present_kv_shape({static_cast(parameters.batch_size_), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(parameters.head_size_)}); Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); - parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, 
parameters.head_size_}); From 5f1fdaea667813ac3577bfb1fb3ea7562990cdc0 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 12 Nov 2024 20:00:00 -0800 Subject: [PATCH 24/43] past_present_share_buffer related changes. --- .../contrib_ops/webgpu/bert/attention.cc | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 2d552150c8284..8bd0a88c89e5d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -116,16 +116,20 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n"; - if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + if (feed_past_key_ && has_present_key_) { shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n"; } if (has_present_key_) { shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n"; } } else { shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; - if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + if (feed_past_key_ && has_present_key_ || past_present_share_buffer_) { shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; } if (has_present_key_) { shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n"; @@ -154,9 +158,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_key_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" << " }\n"; @@ -355,8 +359,10 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n"; - if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { + if (feed_past_value_ && has_present_value_) { shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n"; + } else if (past_present_share_buffer_) { + 
shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n"; } if (has_present_value_) { @@ -364,8 +370,10 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n"; - if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { + if (feed_past_value_ && has_present_value_) { shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n"; } if (has_present_value_) { @@ -395,9 +403,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_value_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; @@ -475,7 +483,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; +const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length}); From 6d2bd68f0325db4332638832b2ef5483af47758d Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Nov 2024 00:31:47 -0800 Subject: [PATCH 25/43] lint --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 8bd0a88c89e5d..35c6c6ef5228a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -147,9 +147,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" - " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" - " }\n"; + " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + " }\n"; } else { shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" @@ -483,7 +483,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; -const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_; + const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length}); From 82a005de6801c827e2ca3053d31ec1b961a39488 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Nov 2024 15:37:00 -0800 Subject: [PATCH 26/43] Fix integer division --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 35c6c6ef5228a..f6966adb57a99 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -147,9 +147,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" - " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" - " }\n"; + << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " }\n"; } else { shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" @@ -220,7 +220,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); } - const uint32_t vectorized_head_size = parameters.head_size_ / components; + const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) @@ -303,7 +303,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso const Tensor* seqlen_k, bool is_first_prompt) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1)); int work_group_size = 64; - const int total_sequence_length_comp = total_sequence_length / components; + const int total_sequence_length_comp = (total_sequence_length + components -1) / components; if (total_sequence_length_comp < work_group_size) { work_group_size = 32; } From fd9409fc11e25dba6a3b16d7fd8b4f3f1686c9d1 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Nov 2024 15:38:20 -0800 Subject: [PATCH 27/43] Updated hints --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index f6966adb57a99..f314056e6eb73 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -225,7 +225,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_, past_key != nullptr, total_sequence_length, seqlen_k != nullptr) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -460,7 +460,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_) + .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_, past_value != nullptr, 
total_sequence_length, seqlen_k != nullptr) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, From 15c96b3d998a9a7b4bf636daf1777cb18bd0104b Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Nov 2024 15:39:42 -0800 Subject: [PATCH 28/43] match jsep code --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index f314056e6eb73..4bce49e36f863 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -158,9 +158,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_key_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n"; } shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" << " }\n"; @@ -403,9 +403,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_value_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n"; } shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; From 72601d1e81d7cbc9a9a9ac769a1b7ecbc3ada973 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Nov 2024 16:40:27 -0800 Subject: [PATCH 29/43] Fixed pastKeyOffset condition so the past_present_share_buffer_ branch is reachable --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 4bce49e36f863..d778a27f1790f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -126,7 +126,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; - if (feed_past_key_ && has_present_key_ || past_present_share_buffer_) { + if (feed_past_key_ && has_present_key_) { shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; } else if (past_present_share_buffer_) { shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; } From 65495b6bfd840075ccfd3a8592156d3668de923d Mon Sep 17
00:00:00 2001 From: Satya Jandhyala Date: Wed, 13 Nov 2024 16:52:39 -0800 Subject: [PATCH 30/43] lint --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index d778a27f1790f..8b86dfc9bb553 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -220,7 +220,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); } - const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; + const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) @@ -303,7 +303,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso const Tensor* seqlen_k, bool is_first_prompt) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1)); int work_group_size = 64; - const int total_sequence_length_comp = (total_sequence_length + components -1) / components; + const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; if (total_sequence_length_comp < work_group_size) { work_group_size = 32; } From 63f20ed317a962556b73457935e3fae489e874a5 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 15 Nov 2024 16:48:09 -0800 Subject: [PATCH 31/43] Fix a bug: total_sequence_length was used instead of uniforms.total_sequence_length_comp --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 8b86dfc9bb553..5bb823937d9c0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -291,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << "}\n"; if (seqlen_k_) { - shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < total_sequence_length; total_seq_id++) {\n" + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" << "}\n"; } From 0102206e8dd236ceb3ba3de20970744479c93f35 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 15 Nov 2024 16:48:29 -0800 Subject: [PATCH 32/43] Revert "match jsep code" This reverts commit 15c96b3d998a9a7b4bf636daf1777cb18bd0104b.
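The reverted change bounded the present_key/present_value writes by uniforms.present_sequence_length. Restoring kv_sequence_length + past_sequence_length limits the writes to the rows actually produced this step; when seqlen_k is supplied, present_sequence_length can instead be the pre-allocated KV-cache capacity (seqlen_present_kv_cache_), and writing up to capacity could copy stale tile rows. A minimal standalone sketch of the restored guard (illustrative helper only, not part of the patch):

    #include <cstdint>

    // Rows of present_key/present_value a workgroup may write in one pass.
    // valid_rows = past rows + new kv rows; slots beyond that (up to the
    // allocated cache capacity) must stay untouched.
    bool WritesPresentRow(bool past_present_share_buffer, uint32_t row,
                          uint32_t past_sequence_length,
                          uint32_t kv_sequence_length) {
      const uint32_t valid_rows = past_sequence_length + kv_sequence_length;
      if (past_present_share_buffer) {
        // Past rows already live in the shared buffer; append only new rows.
        return row >= past_sequence_length && row < valid_rows;
      }
      // Separate buffers: past rows are copied as well, so the whole valid
      // range is written.
      return row < valid_rows;
    }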
--- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 5bb823937d9c0..ee818294474ba 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -158,9 +158,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_key_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" << " }\n"; @@ -403,9 +403,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_present_value_) { if (past_present_share_buffer_) { - shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } else { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n"; + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; From 71ed10c1f69eab7fb10e9f067424c98e53b1d415 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Sun, 17 Nov 2024 12:11:53 -0800 Subject: [PATCH 33/43] Removed is_first_prompt and total_sequence_length from cache hints.
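Cache hints select a compiled shader variant, so only values that change the generated WGSL belong there; per-dispatch scalars should travel as uniforms. In particular, total_sequence_length grows with every decode step, so keying the pipeline cache on it would trigger one shader compilation per token. A standalone sketch of that effect (hypothetical cache and names, illustrative only):

    #include <set>
    #include <string>

    // Hypothetical pipeline cache keyed by the hint string: any value folded
    // into the key forces a new compilation whenever it changes.
    static std::set<std::string> compiled_pipelines;

    static bool NeedsCompile(const std::string& hint) {
      return compiled_pipelines.insert(hint).second;  // true on first sight
    }

    // Decode-loop sketch: total_sequence_length grows by one each token.
    // Keying on tile_size alone compiles once; adding total_sequence_length
    // to the key compiles once per step.
    static int CountCompiles(int steps, bool key_on_total_seq_len) {
      compiled_pipelines.clear();
      int compiles = 0;
      int total_sequence_length = 0;
      for (int i = 0; i < steps; ++i) {
        ++total_sequence_length;
        std::string hint = "tile=12";
        if (key_on_total_seq_len) {
          hint += ";tsl=" + std::to_string(total_sequence_length);
        }
        if (NeedsCompile(hint)) {
          ++compiles;
        }
      }
      return compiles;  // 1 when keyed on tile_size only, `steps` otherwise
    }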
--- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index ee818294474ba..e1abb9cec6460 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -225,7 +225,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_, past_key != nullptr, total_sequence_length, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, past_key != nullptr, seqlen_k != nullptr) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -314,7 +314,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(std::to_string(work_group_size), is_first_prompt) + .CacheHint(std::to_string(work_group_size)) .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -460,7 +460,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_, past_value != nullptr, total_sequence_length, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, past_value != nullptr, seqlen_k != nullptr) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, From 9c08c8219c2023d5d71d6ff5bd0c65ef38649005 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Sun, 17 Nov 2024 23:59:56 -0800 Subject: [PATCH 34/43] Updated hints --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index e1abb9cec6460..20e45f768017a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -225,7 +225,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, past_key != nullptr, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -460,7 +460,7 @@ 
Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, past_value != nullptr, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, @@ -483,7 +483,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length}); From eb5d7b4eed471363cbecc0be72f4e842215d4cf8 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 18 Nov 2024 00:01:00 -0800 Subject: [PATCH 35/43] Use kv_num_heads instead of num_heads for key/value input shape conversion. --- .../webgpu/bert/group_query_attention.cc | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 082666c11161e..09a165cf66907 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -64,8 +64,13 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& output_shape[1] = static_cast(parameters.sequence_length_); output_shape[2] = static_cast(parameters.hidden_size_); Tensor* output = context.Output(0, output_shape); - const int present_kv_seqlen = parameters.seqlen_present_kv_cache_; - std::vector present_kv_shape({static_cast(parameters.batch_size_), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(parameters.head_size_)}); + std::vector present_dims{ + parameters.batch_size_, + kv_num_heads_, + parameters.seqlen_present_kv_cache_, + parameters.head_size_ + }; + std::vector present_kv_shape(present_dims); Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); @@ -76,23 +81,24 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH( context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); - TensorShapeVector
k_new_dims({parameters.batch_size_, parameters.num_heads_, - parameters.kv_sequence_length_, parameters.head_size_}); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, present_value, parameters, context, seqlen_k); } + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); TensorShape k_new_shape(k_new_dims); Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, - parameters.head_size_, key, nullptr, parameters.hidden_size_, &K)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &K)); - TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_, + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_}); TensorShape v_new_shape(v_new_dims); Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, - parameters.v_head_size_, value, nullptr, 2 * parameters.hidden_size_, &V)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &V)); return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, present_value, parameters, context, seqlen_k); } From 7a2d3b6b134e3564dff68f6bb704870f4860705b Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 18 Nov 2024 08:42:52 -0800 Subject: [PATCH 36/43] lint --- .../contrib_ops/webgpu/bert/group_query_attention.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 09a165cf66907..31c8af9b4f922 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -65,11 +65,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& output_shape[2] = static_cast(parameters.hidden_size_); Tensor* output = context.Output(0, output_shape); std::vector present_dims{ - parameters.batch_size_, - kv_num_heads_, - parameters.seqlen_present_kv_cache_, - parameters.head_size_ - }; + parameters.batch_size_, + kv_num_heads_, + parameters.seqlen_present_kv_cache_, + parameters.head_size_}; std::vector present_kv_shape(present_dims); Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); From a48d782ca063fd87de73ab0a3725c7bebcccc0bf Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 19 Nov 2024 19:28:07 -0800 Subject: [PATCH 37/43] changed variable name --- onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 286991cf26135..d1316d5a932d1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -15,7 +15,7 @@ namespace contrib { 
namespace webgpu { struct WebgpuAttentionParameters { - WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false), + WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false), batch_size_(parameters.batch_size), sequence_length_(parameters.sequence_length), kv_sequence_length_(parameters.kv_sequence_length), @@ -39,7 +39,7 @@ struct WebgpuAttentionParameters { qkv_format_(parameters.qkv_format) { } - WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true), + WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true), batch_size_(parameters.batch_size), sequence_length_(parameters.sequence_length), kv_sequence_length_(parameters.sequence_length), @@ -69,7 +69,7 @@ struct WebgpuAttentionParameters { qkv_format_(parameters.qkv_format) { } - bool is_gqa_parameters_; + bool is_gqa_; int batch_size_ = 0; int sequence_length_ = 0; int kv_sequence_length_ = 0; // input sequence length of K or V From 4334b396a355bbe3f1d6326983504cd1eaf7a802 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Nov 2024 08:31:53 -0800 Subject: [PATCH 38/43] Removed is_first_prompt from uniforms; it is now used as a compile-time condition when generating shader code and added to the cache hint. --- .../contrib_ops/webgpu/bert/attention.cc | 44 ++++++++----------- .../contrib_ops/webgpu/bert/attention.h | 24 +++++----- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 20e45f768017a..891aa3a425b70 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -68,13 +68,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; -void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { +void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) { if (seqlen_k != nullptr) { ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; - ss << "var past_sequence_length: u32 = 0;\n"; - ss << "if (uniforms.is_first_prompt != 0) {\n"; - ss << " past_sequence_length = total_sequence_length - sequence_length;\n"; - ss << "}\n"; + ss << "var past_sequence_length: u32 = " << (is_first_prompt ?
"0" : "total_sequence_length - sequence_length") << ";\n"; } else { ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; } @@ -108,7 +105,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, seqlen_k_, is_first_prompt_); shader.MainFunctionBody() << oss.str(); if (n_reps_ > 1) { shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" @@ -122,7 +119,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n"; } if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n"; } } else { shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; @@ -132,7 +129,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; } if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; } } @@ -203,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -225,7 +222,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -235,8 +232,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? 
total_sequence_length : parameters.seqlen_present_kv_cache_)}, - {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_)}}) + {static_cast(parameters.n_reps)}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); @@ -255,7 +251,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.sequence_length;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, seqlen_k_, is_first_prompt_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" @@ -309,12 +305,12 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k}; + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k}; if (seqlen_k != nullptr) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(std::to_string(work_group_size)) + .CacheHint(work_group_size, is_first_prompt) .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -322,8 +318,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {static_cast(past_sequence_length)}, {static_cast(sequence_length)}, {static_cast(total_sequence_length_comp)}, - {static_cast(elementsPerThread)}, - {static_cast(is_first_prompt)}}); + {static_cast(elementsPerThread)}}); return context.RunProgram(program); } @@ -352,7 +347,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, seqlen_k_, is_first_prompt_); shader.MainFunctionBody() << oss.str(); if (n_reps_ > 1) { shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n" @@ -366,7 +361,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n"; } } else { shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n"; @@ -377,7 +372,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.K + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n"; } } @@ -407,7 +402,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } else { shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) 
{\n"; } - shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.present_sequence_length] = tileK[idx];\n" << " }\n"; } @@ -443,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { @@ -460,7 +455,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, @@ -471,8 +466,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? 
total_sequence_length : parameters.seqlen_present_kv_cache_)}, - {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_)}}) + {static_cast(parameters.n_reps)}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index ee3aa1957cdd9..03279fffbc3ef 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -49,8 +49,7 @@ class AttentionProbsProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"n_reps", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -63,12 +62,13 @@ class AttentionProbsProgram final : public Program { int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; + bool is_first_prompt_; }; class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -78,19 +78,19 @@ class InPlaceSoftmaxProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, - {"elements_per_thread", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"elements_per_thread", 
ProgramUniformVariableDataType::Uint32}); private: int work_group_size_; int components_; const Tensor* seqlen_k_; bool is_first_prompt_; }; class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -104,8 +104,7 @@ class VxAttentionScoreProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"n_reps", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -116,6 +115,7 @@ class VxAttentionScoreProgram final : public Program { int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; + bool is_first_prompt_; }; } // namespace webgpu From d53d7ef612a68d2cdeb125fa2a3e0de50e7224a0 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Nov 2024 14:48:38 -0800 Subject: [PATCH 39/43] Fix present_value indexing: revert row stride to uniforms.N --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 891aa3a425b70..ea8aa95614b40 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -402,7 +402,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } else { shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; } - shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.present_sequence_length] = tileK[idx];\n" + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" << " }\n"; From 5dc95c8359b76fc0d2c09ba3a62004e821029fb5 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 21 Nov 2024 00:07:49 -0800 Subject: [PATCH 40/43] initialize scale --- onnxruntime/contrib_ops/webgpu/bert/attention_common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index d1316d5a932d1..b7137ef0aec3a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -51,6 +51,7 @@ struct WebgpuAttentionParameters {
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), num_heads_(parameters.num_heads), do_rotary_(parameters.do_rotary), + scale_(parameters.scale), seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), kv_hidden_size_(parameters.kv_hidden_size), From e448b1aa76c11e24d5c5c795a5116eb44335cd7d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 22 Nov 2024 09:06:37 -0800 Subject: [PATCH 41/43] Calculate output chunk size based on whether the kernel is GQA or not. --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 11 ++++++----- onnxruntime/contrib_ops/webgpu/bert/attention.h | 10 ++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index ea8aa95614b40..3862510699646 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -172,7 +172,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "}\n"; shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" - << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" + << " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n" << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; @@ -200,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -416,8 +416,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" << "if (m < uniforms.M && n < uniforms.N) {\n" - << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " - << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" + << " let tmp = " << (is_gqa_ ? 
"uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n" + << " let outputIdx = batch_idx * uniforms.M * tmp + " + << " m * tmp + head_idx * uniforms.N + n;\n" << " output[outputIdx] = value;\n" << "}\n"; @@ -438,7 +439,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 03279fffbc3ef..350f2387920f0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -63,6 +63,7 @@ class AttentionProbsProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; + bool is_gqa_; }; class InPlaceSoftmaxProgram final : public Program { @@ -89,8 +90,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool 
past_present_share_buffer, bool is_gqa) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -116,6 +117,7 @@ class VxAttentionScoreProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; + bool is_gqa_; }; } // namespace webgpu From 60af2f52c1a9d6ce1edf8d4dd828b2cc1689ef92 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 22 Nov 2024 23:16:15 -0800 Subject: [PATCH 42/43] Revert "Calculate output chunk size based on whether the kernel is GQA or not." This reverts commit e448b1aa76c11e24d5c5c795a5116eb44335cd7d. --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 11 +++++------ onnxruntime/contrib_ops/webgpu/bert/attention.h | 10 ++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 3862510699646..ea8aa95614b40 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -172,7 +172,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "}\n"; shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" - << " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n" + << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; @@ -200,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_}; + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -416,9 +416,8 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" << "if (m < uniforms.M && n < uniforms.N) {\n" - << " let tmp = " << (is_gqa_ ? 
"uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n" - << " let outputIdx = batch_idx * uniforms.M * tmp + " - << " m * tmp + head_idx * uniforms.N + n;\n" + << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" << " output[outputIdx] = value;\n" << "}\n"; @@ -439,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const bool has_present_value = output_count > 1 && past_value != nullptr; const int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 350f2387920f0..03279fffbc3ef 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; - bool is_gqa_; }; class InPlaceSoftmaxProgram final : public Program { @@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) { + VxAttentionScoreProgram(const 
std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; - bool is_gqa_; }; } // namespace webgpu From 47e6f525682478ab3ff4fcb2287dbb57b38b7fd8 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 22 Nov 2024 23:45:48 -0800 Subject: [PATCH 43/43] Fix v_hidden_size uniform: scale v_hidden_size_ by n_reps --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index ea8aa95614b40..089cde1669385 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -462,7 +462,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.v_head_size_)}, {static_cast(parameters.num_heads_)}, {static_cast(parameters.head_size_)}, - {static_cast(parameters.v_hidden_size_)}, + {static_cast(parameters.v_hidden_size_ * parameters.n_reps)}, {static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},