Commit
Merge remote-tracking branch 'origin/sajandhy/webgpu-ep-gqa-new' into gs/wgpu
guschmue committed Nov 21, 2024
2 parents f6e1d44 + 5dc95c8 commit 92ca0d2
Showing 9 changed files with 962 additions and 517 deletions.
38 changes: 20 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -11,18 +11,19 @@ namespace onnxruntime {
namespace contrib {
namespace group_query_attention_helper {

-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T = Tensor>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap) {
  // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
@@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
  return Status::OK();
}

-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T = Tensor>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap,
                    int max_threads_per_block) {
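The helper is now templated on the tensor type, with Tensor as the default, so existing CPU call sites that pass const Tensor* keep compiling unchanged while the WebGPU execution provider can reuse the same validation with its own tensor-like type. A minimal sketch of the mechanism (the Shape() member and both struct types here are hypothetical stand-ins, not types from this commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for onnxruntime::Tensor and a provider-specific wrapper.
struct Tensor {
  std::vector<int64_t> dims;
  const std::vector<int64_t>& Shape() const { return dims; }
};
struct GpuTensorView {
  std::vector<int64_t> dims;
  const std::vector<int64_t>& Shape() const { return dims; }
};

// The same helper accepts any tensor-like type that exposes the members it uses;
// T defaults to Tensor, matching the original non-template signature.
template <typename T = Tensor>
bool CheckRank3(const T* query) {
  return query != nullptr && query->Shape().size() == 3;
}

int main() {
  Tensor cpu_q{{1, 8, 64}};
  GpuTensorView gpu_q{{1, 8, 64}};
  std::cout << CheckRank3(&cpu_q) << " "    // existing-style call, T deduced as Tensor
            << CheckRank3(&gpu_q) << "\n";  // new call with a different tensor type
}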
500 changes: 500 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc

Large diffs are not rendered by default.

123 changes: 123 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
public:
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
{"batch_offset", ProgramUniformVariableDataType::Uint32},
{"sequence_offset", ProgramUniformVariableDataType::Uint32},
{"head_offset", ProgramUniformVariableDataType::Uint32},
{"bias_offset", ProgramUniformVariableDataType::Uint32});

private:
bool has_bias_;
};
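TransferBSDToBNSH relayouts the packed (batch, sequence, hidden) activations into the (batch, num_heads, sequence, head_size) layout the attention programs consume, optionally adding a bias at the given offsets. A host-side sketch of the index mapping, without bias (the real work is a WebGPU shader; this loop only illustrates the layout change):

#include <cstddef>
#include <vector>

// Illustrative CPU equivalent of the BSD -> BNSH relayout, bias omitted.
std::vector<float> TransferBSDToBNSHHost(const std::vector<float>& input,
                                         int batch, int seq_len, int num_heads, int head_size) {
  std::vector<float> output(input.size());
  for (int b = 0; b < batch; ++b) {
    for (int s = 0; s < seq_len; ++s) {
      for (int n = 0; n < num_heads; ++n) {
        for (int h = 0; h < head_size; ++h) {
          // BSD: the hidden dimension is num_heads * head_size, heads interleaved per token.
          size_t src = ((static_cast<size_t>(b) * seq_len + s) * num_heads + n) * head_size + h;
          // BNSH: the head dimension is hoisted ahead of the sequence dimension.
          size_t dst = ((static_cast<size_t>(b) * num_heads + n) * seq_len + s) * head_size + h;
          output[dst] = input[src];
        }
      }
    }
  }
  return output;
}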

class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

private:
bool feed_past_key_;
bool has_present_key_;
bool has_attention_bias_;
int tile_size_;
int components_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};
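AttentionProbsProgram computes the raw attention scores: a scaled dot product of each query row against the (optionally past-concatenated) keys, tiled by TILE_SIZE and split across heads. A scalar sketch of what the M/K/N/alpha uniforms describe (tiling, past-key handling, and the n_reps KV-head sharing are omitted; treating alpha as 1/sqrt(head_size) when no explicit scale is set is an assumption based on standard attention, not something this header states):

#include <cstddef>
#include <vector>

// scores[m][n] = alpha * dot(Q[m], K[n]) for M query rows and N keys of width head_size.
std::vector<std::vector<float>> ComputeAttentionProbs(
    const std::vector<std::vector<float>>& Q,     // M x head_size
    const std::vector<std::vector<float>>& Kmat,  // N x head_size
    float alpha) {
  const size_t M = Q.size(), N = Kmat.size(), K = Q.empty() ? 0 : Q[0].size();
  std::vector<std::vector<float>> scores(M, std::vector<float>(N, 0.0f));
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float dot = 0.0f;
      for (size_t k = 0; k < K; ++k) dot += Q[m][k] * Kmat[n][k];
      scores[m][n] = alpha * dot;  // alpha is commonly 1 / sqrt(head_size)
    }
  }
  return scores;
}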

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});

private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
bool is_first_prompt_;
};
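InPlaceSoftmaxProgram then normalizes each scores row in place, so no second buffer is needed for the probabilities. A single-threaded sketch of the numerically stable in-place form (the shader distributes each row across work_group_size threads and elements_per_thread chunks; this scalar version only illustrates the math):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax computed in place over one row of attention scores.
void InPlaceSoftmaxRow(std::vector<float>& row) {
  if (row.empty()) return;
  const float max_val = *std::max_element(row.begin(), row.end());
  float sum = 0.0f;
  for (float& x : row) {
    x = std::exp(x - max_val);  // subtract the row max to avoid overflow
    sum += x;
  }
  for (float& x : row) {
    x /= sum;  // the row now sums to 1
  }
}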

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"v_hidden_size", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

private:
bool feed_past_value_;
bool has_present_value_;
int tile_size_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
130 changes: 130 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"

#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace webgpu {

struct WebgpuAttentionParameters {
WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false),
batch_size_(parameters.batch_size),
sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.kv_sequence_length),
past_sequence_length_(parameters.past_sequence_length),
total_sequence_length_(parameters.total_sequence_length),
max_sequence_length_(parameters.max_sequence_length),
input_hidden_size_(parameters.input_hidden_size),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.v_hidden_size),
v_head_size_(parameters.v_head_size),
num_heads_(parameters.num_heads),
is_unidirectional_(parameters.is_unidirectional),
past_present_share_buffer_(parameters.past_present_share_buffer),
do_rotary_(parameters.do_rotary),
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0),
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1),
mask_filter_value_(parameters.mask_filter_value),
scale_(parameters.scale),
mask_type_(parameters.mask_type),
qkv_format_(parameters.qkv_format) {
}

WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true),
batch_size_(parameters.batch_size),
sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.sequence_length),
past_sequence_length_(parameters.seqlen_past_kv_cache),
total_sequence_length_(parameters.total_sequence_length),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.kv_hidden_size),
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
num_heads_(parameters.num_heads),
do_rotary_(parameters.do_rotary),
scale_(parameters.scale),
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache),
kv_hidden_size_(parameters.kv_hidden_size),
kv_num_heads_(parameters.kv_num_heads),
num_splits_(parameters.num_splits),
rotary_dim_(parameters.rotary_dim),
is_packed_qkv_(parameters.is_packed_qkv),
is_subsequent_prompt_(parameters.is_subsequent_prompt),
is_first_prompt_(parameters.is_first_prompt),
rotary_interleaved_(parameters.rotary_interleaved),
use_smooth_softmax_(parameters.use_smooth_softmax),
softcap_(parameters.softcap),
zeros_count_(parameters.zeros_count),
zero_ptr_(parameters.zero_ptr),
n_reps(parameters.num_heads / parameters.kv_num_heads),
qkv_format_(parameters.qkv_format) {
}

bool is_gqa_;
int batch_size_ = 0;
int sequence_length_ = 0;
int kv_sequence_length_ = 0; // input sequence length of K or V
int past_sequence_length_ = 0; // sequence length in past state of K or V
int total_sequence_length_ = 0; // total sequence length of K or V
int max_sequence_length_ = 0; // max sequence length from 4D mask
int input_hidden_size_ = 0; // first dimension of weights for input projection
int hidden_size_ = 0; // hidden size of Q or K
int head_size_ = 0; // hidden size per head of Q or K
int v_hidden_size_ = 0; // hidden size of V
int v_head_size_ = 0; // hidden size per head of V
int num_heads_ = 0;
int rotary_embedding_ = 0;
bool is_unidirectional_ = false;
bool past_present_share_buffer_ = false;
bool do_rotary_ = false;
bool broadcast_attn_bias_dim_0_ = false;
bool broadcast_attn_bias_dim_1_ = false;
float mask_filter_value_ = -10000.0f;
float scale_ = 0.0f;
bool use_tf32_ = false;
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters
// and not in onnxruntime::contrib::AttentionParameters
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor
int kv_hidden_size_ = 0;
int kv_num_heads_ = 0;
int num_splits_ = 0; // number of splits for splitkv
int rotary_dim_ = 0; // rotary embedding dimension
int local_window_size_ = 0;
bool kv_share_buffer_ = false;
bool is_packed_qkv_ = false;
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1
bool is_first_prompt_ = false; // indicates whether this is first decoding step
bool rotary_interleaved_ = false;
bool use_smooth_softmax_ = false;
float softcap_ = 0.0f;
int zeros_count_ = 0;
int* zero_ptr_ = nullptr;
// Computed values
int n_reps = 1;
AttentionMaskType mask_type_ = MASK_NONE;
AttentionQkvFormat qkv_format_ = UNKNOWN;
};
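In the GroupQueryAttentionParameters constructor above, n_reps is derived as num_heads / kv_num_heads: the number of query heads that share a single cached K/V head. A small illustrative helper for the conventional mapping (how the shaders actually index KV heads is not shown in this header, so this is an assumption):

// Map a query head to the KV head it reads under grouped-query attention,
// assuming consecutive query heads are grouped onto each KV head.
int KvHeadForQueryHead(int query_head, int n_reps) {
  return query_head / n_reps;
}
// Example: num_heads = 32, kv_num_heads = 8 gives n_reps = 4, so query heads
// 0..3 read KV head 0, heads 4..7 read KV head 1, and so on.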

Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
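Taken together, a kernel is expected to relayout Q/K/V with TransferBSDToBNSH and then hand everything to ApplyAttention. A hypothetical caller shape using only the two declarations above (attention.cc is not shown in this diff, so the surrounding allocation and the exact call pattern are assumptions; ORT_RETURN_IF_ERROR is the usual onnxruntime status-propagation macro):

// Hypothetical composition, not the commit's attention.cc. All scratch and
// output tensors are assumed to be allocated by the caller.
Status RunAttentionSketch(onnxruntime::webgpu::ComputeContext& context,
                          WebgpuAttentionParameters& parameters,
                          const Tensor* q_bsd, const Tensor* k_bsd, const Tensor* v_bsd,
                          const Tensor* attention_bias,
                          const Tensor* past_key, const Tensor* past_value,
                          Tensor* q_bnsh, Tensor* k_bnsh, Tensor* v_bnsh,
                          Tensor* output, Tensor* present_key, Tensor* present_value,
                          const Tensor* seqlen_k) {
  // Q uses the query head count; K and V use the KV head count and their own head sizes.
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
                                        parameters.head_size_, q_bsd, /*bias*/ nullptr, /*bias_offset*/ 0, q_bnsh));
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
                                        parameters.head_size_, k_bsd, nullptr, 0, k_bnsh));
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
                                        parameters.v_head_size_, v_bsd, nullptr, 0, v_bnsh));
  return ApplyAttention(q_bnsh, k_bnsh, v_bnsh, attention_bias, past_key, past_value,
                        output, present_key, present_value, parameters, context, seqlen_k);
}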