-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WebGPU EP] Support GroupQueryAttention #22658
Open
satyajandhyala
wants to merge
44
commits into
main
Choose a base branch
from
sajandhy/webgpu-ep-gqa-new
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+962
−517
Open
Changes from 29 commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
0a5d212
Added attention_common.h
satyajandhyala 5bfa070
wip
satyajandhyala e6615e9
Fix compilation errors
satyajandhyala 449afb4
lint
satyajandhyala 8d10472
Modified MultiHeadAttention to not derive from AttentionBase class
satyajandhyala 4ea58d1
Uncomment GQA registration
satyajandhyala 4bcf257
Moved TransferBSToBNSH and ApplyAttention declaration to attention_co…
satyajandhyala 5c5c934
Revert "Modified MultiHeadAttention to not derive from AttentionBase …
satyajandhyala e716546
Converted CheckInput function to template to fix compiler/linker mult…
satyajandhyala aba59e5
lint
satyajandhyala 067ecd1
Fixed conflicts.
satyajandhyala 53f1c78
copying errors
satyajandhyala f4dc9fc
Fixed inplacesoftmax dispatch
satyajandhyala 3d1af1c
Initialize required parameter data
satyajandhyala 2eaeebc
Map total_seqlen_tensor input to CPU
satyajandhyala 9c828cc
Use uniforms variable name consistently to avoid confusion.
satyajandhyala 26caa06
Keep InplaceSoftmax dispatch 3-dim.
satyajandhyala 64b093f
Formatting changes.
satyajandhyala a8bd38b
Use total_seqlen_tensor input only to determin is_first_prompt.
satyajandhyala d613df4
initialize is_packed_qkv_
satyajandhyala 0fedb9f
Handle past key/value and present key/value buffer sharing.
satyajandhyala 993140b
lint
satyajandhyala 7502493
Added past_present_share_buffer to the hint. typo
satyajandhyala 5f1fdae
past_present_share_buffer related changes.
satyajandhyala 6d2bd68
lint
satyajandhyala 82a005d
Fix integer division
satyajandhyala fd9409f
Updated hints
satyajandhyala 15c96b3
match jsep code
satyajandhyala 72601d1
Fixed a minor issue
satyajandhyala 65495b6
lint
satyajandhyala 63f20ed
Fix a bug using total_sequence_length instead of uniform.total_sequen…
satyajandhyala 0102206
Revert "match jsep code"
satyajandhyala 71ed10c
Removed is_first_prompt from uniforms.
satyajandhyala 9c08c82
Updated hints
satyajandhyala eb5d7b4
Use kv_num_heads instead num_heads for key/value input shape conversion.
satyajandhyala 7a2d3b6
lint
satyajandhyala 664022f
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala a48d782
changed variable name
satyajandhyala 4334b39
Removed is_first_prompt from uniforms, used in a condition generating…
satyajandhyala d53d7ef
error
satyajandhyala 5dc95c8
initialize scale
satyajandhyala e448b1a
Calculate output chunk size based on whether the kernel is GQA or not.
satyajandhyala 60af2f5
Revert "Calculate output chunk size based on whether the kernel is GQ…
satyajandhyala 47e6f52
Bug fix
satyajandhyala File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/compute_context.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "contrib_ops/webgpu/bert/attention_common.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
using namespace onnxruntime::webgpu; | ||
|
||
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> { | ||
public: | ||
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, | ||
{"batch_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"sequence_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"head_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"bias_offset", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
bool has_bias_; | ||
}; | ||
|
||
class AttentionProbsProgram final : public Program<AttentionProbsProgram> { | ||
public: | ||
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, | ||
bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) | ||
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, | ||
{"K", ProgramUniformVariableDataType::Uint32}, | ||
{"N", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"head_size", ProgramUniformVariableDataType::Uint32}, | ||
{"alpha", ProgramUniformVariableDataType::Float32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"present_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"n_reps", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); | ||
|
||
private: | ||
bool feed_past_key_; | ||
bool has_present_key_; | ||
bool has_attention_bias_; | ||
int tile_size_; | ||
int components_; | ||
int n_reps_; | ||
const Tensor* seqlen_k_; | ||
bool past_present_share_buffer_; | ||
}; | ||
|
||
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> { | ||
public: | ||
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) | ||
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, | ||
{"elements_per_thread", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
int work_group_size_; | ||
int components_; | ||
const Tensor* seqlen_k_; | ||
}; | ||
|
||
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> { | ||
public: | ||
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) | ||
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, | ||
{"K", ProgramUniformVariableDataType::Uint32}, | ||
{"N", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"head_size", ProgramUniformVariableDataType::Uint32}, | ||
{"v_hidden_size", ProgramUniformVariableDataType::Uint32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"present_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"n_reps", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); | ||
|
||
private: | ||
bool feed_past_value_; | ||
bool has_present_value_; | ||
int tile_size_; | ||
int n_reps_; | ||
const Tensor* seqlen_k_; | ||
bool past_present_share_buffer_; | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/compute_context.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "contrib_ops/webgpu/bert/attention_common.h" | ||
|
||
#include "contrib_ops/cpu/bert/attention_common.h" | ||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
struct WebgpuAttentionParameters { | ||
WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false), | ||
batch_size_(parameters.batch_size), | ||
sequence_length_(parameters.sequence_length), | ||
kv_sequence_length_(parameters.kv_sequence_length), | ||
past_sequence_length_(parameters.past_sequence_length), | ||
total_sequence_length_(parameters.total_sequence_length), | ||
max_sequence_length_(parameters.max_sequence_length), | ||
input_hidden_size_(parameters.input_hidden_size), | ||
hidden_size_(parameters.hidden_size), | ||
head_size_(parameters.head_size), | ||
v_hidden_size_(parameters.v_hidden_size), | ||
v_head_size_(parameters.v_head_size), | ||
num_heads_(parameters.num_heads), | ||
is_unidirectional_(parameters.is_unidirectional), | ||
past_present_share_buffer_(parameters.past_present_share_buffer), | ||
do_rotary_(parameters.do_rotary), | ||
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), | ||
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), | ||
mask_filter_value_(parameters.mask_filter_value), | ||
scale_(parameters.scale), | ||
mask_type_(parameters.mask_type), | ||
qkv_format_(parameters.qkv_format) { | ||
} | ||
|
||
WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true), | ||
batch_size_(parameters.batch_size), | ||
sequence_length_(parameters.sequence_length), | ||
kv_sequence_length_(parameters.sequence_length), | ||
past_sequence_length_(parameters.seqlen_past_kv_cache), | ||
total_sequence_length_(parameters.total_sequence_length), | ||
hidden_size_(parameters.hidden_size), | ||
head_size_(parameters.head_size), | ||
v_hidden_size_(parameters.kv_hidden_size), | ||
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), | ||
num_heads_(parameters.num_heads), | ||
do_rotary_(parameters.do_rotary), | ||
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), | ||
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), | ||
kv_hidden_size_(parameters.kv_hidden_size), | ||
kv_num_heads_(parameters.kv_num_heads), | ||
num_splits_(parameters.num_splits), | ||
rotary_dim_(parameters.rotary_dim), | ||
is_packed_qkv_(parameters.is_packed_qkv), | ||
is_subsequent_prompt_(parameters.is_subsequent_prompt), | ||
is_first_prompt_(parameters.is_first_prompt), | ||
rotary_interleaved_(parameters.rotary_interleaved), | ||
use_smooth_softmax_(parameters.use_smooth_softmax), | ||
softcap_(parameters.softcap), | ||
zeros_count_(parameters.zeros_count), | ||
zero_ptr_(parameters.zero_ptr), | ||
n_reps(parameters.num_heads / parameters.kv_num_heads), | ||
qkv_format_(parameters.qkv_format) { | ||
} | ||
|
||
bool is_gqa_parameters_; | ||
int batch_size_ = 0; | ||
int sequence_length_ = 0; | ||
int kv_sequence_length_ = 0; // input sequence length of K or V | ||
int past_sequence_length_ = 0; // sequence length in past state of K or V | ||
int total_sequence_length_ = 0; // total sequence length of K or V | ||
int max_sequence_length_ = 0; // max sequence length from 4D mask | ||
int input_hidden_size_ = 0; // first dimension of weights for input projection | ||
int hidden_size_ = 0; // hidden size of Q or K | ||
int head_size_ = 0; // hidden size per head of Q or K | ||
int v_hidden_size_ = 0; // hidden size of V | ||
int v_head_size_ = 0; // hidden size per head of V | ||
int num_heads_ = 0; | ||
int rotary_embedding_ = 0; | ||
bool is_unidirectional_ = false; | ||
bool past_present_share_buffer_ = false; | ||
bool do_rotary_ = false; | ||
bool broadcast_attn_bias_dim_0_ = false; | ||
bool broadcast_attn_bias_dim_1_ = false; | ||
float mask_filter_value_ = -10000.0f; | ||
float scale_ = 0.0f; | ||
bool use_tf32_ = false; | ||
; | ||
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters | ||
// and not in onnxruntime::contrib::AttentionParameters | ||
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor | ||
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor | ||
int kv_hidden_size_ = 0; | ||
int kv_num_heads_ = 0; | ||
int num_splits_ = 0; // number of splits for splitkv | ||
int rotary_dim_ = 0; // rotary embedding dimension | ||
int local_window_size_ = 0; | ||
bool kv_share_buffer_ = false; | ||
bool is_packed_qkv_ = false; | ||
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 | ||
bool is_first_prompt_ = false; // indicates whether this is first decoding step | ||
bool rotary_interleaved_ = false; | ||
bool use_smooth_softmax_ = false; | ||
float softcap_ = 0.0; | ||
int zeros_count_ = 0; | ||
; | ||
int* zero_ptr_ = nullptr; | ||
// Computed values | ||
int n_reps = 1; | ||
AttentionMaskType mask_type_ = MASK_NONE; | ||
AttentionQkvFormat qkv_format_ = UNKNOWN; | ||
}; | ||
|
||
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, | ||
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); | ||
|
||
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, | ||
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, | ||
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason WebGPU needs a parameters struct that combines AttentionParameters and GroupQueryAttentionParameters? Feels a little confusing to merge those and wondering why it's necessary if we don't need to do that for other EPs that implement these operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am trying to avoid code duplication. I refactored code into attention used by both GQA and MHA. The CPU version has GQA separate implementation. group_query_attention_helper::CheckInputs() and AttentionBase::CheckInputs output different structs, GroupQueryAttentionParameters and AttentionParameters respectively. WebGPU parameters is a union of these to structs.