Calculate output chunk size based on whether the kernel is GQA or not.
satyajandhyala committed Nov 22, 2024
1 parent 5dc95c8 commit e448b1a
Showing 2 changed files with 12 additions and 9 deletions.
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -172,7 +172,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "}\n";

shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
<< " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";

@@ -200,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                              components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+                              components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -416,8 +416,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
<< "if (m < uniforms.M && n < uniforms.N) {\n"
<< " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
<< " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
<< " let tmp = " << (is_gqa_ ? "uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n"
<< " let outputIdx = batch_idx * uniforms.M * tmp + "
<< " m * tmp + head_idx * uniforms.N + n;\n"
<< " output[outputIdx] = value;\n"
<< "}\n";

@@ -438,7 +439,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const bool has_present_value = output_count > 1 && past_value != nullptr;
const int tile_size = 12;

[Code scanning / PREfast] Warning: The const variable 'tile_size' can be computed at compile-time. Consider using constexpr (con.5).
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
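On the PREfast note above: since 12 is a literal, the suggested fix (not applied in this commit) would simply be

constexpr int tile_size = 12;  // computed at compile time, per con.5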
10 changes: 6 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -63,6 +63,7 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
+  bool is_gqa_;
};
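Both program classes follow the same pattern: the GQA flag is captured once at construction (note the defaults on the trailing parameters were dropped, so every call site must now pass it explicitly) and branched on while emitting WGSL. A minimal standalone sketch of that pattern, using a hypothetical class rather than the real Program hierarchy:

#include <iostream>
#include <sstream>
#include <string>

// Stand-in for the Program subclasses above: the flag is fixed at
// construction and selects which uniform sizes the output chunk.
class ChunkSizedProgram {
 public:
  explicit ChunkSizedProgram(bool is_gqa) : is_gqa_(is_gqa) {}

  std::string GenerateHeadOffset() const {
    std::ostringstream ss;
    ss << "let headOffset = workgroup_id.z * uniforms.M * "
       << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N")
       << ";\n";
    return ss.str();
  }

 private:
  bool is_gqa_;
};

int main() {
  std::cout << ChunkSizedProgram{true}.GenerateHeadOffset();   // GQA path
  std::cout << ChunkSizedProgram{false}.GenerateHeadOffset();  // regular path
  return 0;
}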

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
@@ -89,8 +90,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
-  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {

[GitHub Actions / Optional Lint C++, cpplint] onnxruntime/contrib_ops/webgpu/bert/attention.h:93: Add #include <string> for string [build/include_what_you_use] [4]
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -116,6 +117,7 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
+  bool is_gqa_;
};

} // namespace webgpu
