Skip to content

Commit

Permalink
Revert "Calculate output chunk size based on whether the kernel is GQA or not."

Browse files Browse the repository at this point in the history

This reverts commit e448b1a.
  • Loading branch information
satyajandhyala committed Nov 23, 2024
1 parent e448b1a commit 60af2f5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
11 changes: 5 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "}\n";

shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
<< " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n"
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";

Expand Down Expand Up @@ -200,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
Expand Down Expand Up @@ -416,9 +416,8 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
<< "if (m < uniforms.M && n < uniforms.N) {\n"
<< " let tmp = " << (is_gqa_ ? "uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n"
<< " let outputIdx = batch_idx * uniforms.M * tmp + "
<< " m * tmp + head_idx * uniforms.N + n;\n"
<< " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
<< " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
<< " output[outputIdx] = value;\n"
<< "}\n";

Expand All @@ -439,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const bool has_present_value = output_count > 1 && past_value != nullptr;
const int tile_size = 12;

VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand Down Expand Up @@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_gqa_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
Expand All @@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_gqa_;
};

} // namespace webgpu
Expand Down

0 comments on commit 60af2f5

Please sign in to comment.