FA works on Intel (TILE_SIZE == SUBGROUP_SIZE) for seq length of 1.
sushraja-msft committed Nov 21, 2024
1 parent 80296aa commit c281f84
Showing 2 changed files with 72 additions and 35 deletions.
102 changes: 69 additions & 33 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -421,23 +421,29 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Attention bias is in BN(total_sequence_length)
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (has_past_) {
shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
}
shader.AddOutput("present_key", ShaderUsage::UseUniform);
shader.AddOutput("present_value", ShaderUsage::UseUniform);

shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
<< "let kIdx = workgroup_id.x;\n"
<< "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"
<< "if (kIdx < uniforms.past_sequence_length) {\n"
<< " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
<< " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n"
<< " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n"
<< " }\n"
<< "}\n"
<< "else if (kIdx >= uniforms.past_sequence_length) {\n"
<< " let nkIdx = kIdx - uniforms.past_sequence_length;\n"
<< "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n";
if (has_past_) {
shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n"
<< " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
<< " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n"
<< " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n"
<< " }\n"
<< "}\n"
<< "else if (kIdx >= uniforms.past_sequence_length) {\n";
} else {
shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n";
}
shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n"
<< " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n"
<< " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n"
<< " // Assumes kv have BNSH layout.\n"
@@ -457,17 +463,24 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParame
int past_sequence_length, int total_sequence_length) {

const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1);
CopyKVCacheProgram program{"CopyKVCache", components};
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_key, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
bool has_past = (past_sequence_length != 0);
CopyKVCacheProgram program{"CopyKVCache", components, has_past};
if (has_past) {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_key, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
} else {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
}

program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
{present_value, ProgramTensorMetadataDependency::Rank, components}});

program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads)
.SetWorkgroupSize(1)
.CacheHint(std::to_string(components))
.CacheHint(std::to_string(components) + std::to_string(has_past))
.AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length)},
{static_cast<uint32_t>(parameters.head_size/ components)}});
@@ -669,15 +682,15 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool)
fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
{
var attn = q_element_t(0);
if (enabled)
{
attn = qk_tile[q_idx][sg_id];
}
for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++)
{
let attn = qk_tile[q_idx][sg_id];
let val = v_tile[sg_id][i];
var intermediate = q_value_t(0);
if (enabled)
{
intermediate = attn * val;
}
var intermediate = attn * val;
let sum = subgroupAdd(intermediate);
if (sg_id == 0)
{
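To see what the rewritten loop computes, here is a scalar stand-in for one subgroupAdd in computeO (hypothetical helper, not part of the commit; it assumes each lane owns one key position of the tile, i.e. TILE == subgroup size). Zeroing attn once before the loop, as the new code does, is equivalent to the old per-iteration guard, because a disabled lane then contributes attn * val == 0 to every sum:

  #include <array>
  #include <cstddef>

  // One subgroup reduction for a single query row q_idx and vectorized output
  // element i: lane sg_id holds attn_row[sg_id] (qk_tile[q_idx][sg_id]) and
  // v_column[sg_id] (v_tile[sg_id][i]); subgroupAdd hands every lane the sum of
  // the per-lane products, which lane 0 then accumulates into the output tile.
  template <std::size_t TILE>
  float SubgroupDotReference(const std::array<float, TILE>& attn_row,
                             const std::array<float, TILE>& v_column,
                             const std::array<bool, TILE>& enabled) {
    float sum = 0.0f;
    for (std::size_t lane = 0; lane < TILE; ++lane) {
      const float attn = enabled[lane] ? attn_row[lane] : 0.0f;  // lanes past total_sequence_length contribute nothing
      sum += attn * v_column[lane];
    }
    return sum;
  }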
@@ -733,10 +746,33 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length));
//return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key,
// present_value, parameters, context, true);

constexpr int subgroup_size = 32;
// // Uncomment to test CopyKVCache independent of FlashAttentionProgram.
// TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.sequence_length, parameters.head_size});
// TensorShape q_new_shape(q_new_dims);
// Tensor Qn = context.CreateGPUTensor(Q->DataType(), q_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
// context, parameters.num_heads, parameters.sequence_length, parameters.head_size, Q, nullptr, 0, &Qn));

// TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.kv_sequence_length, parameters.head_size});
// TensorShape k_new_shape(k_new_dims);
// Tensor Kn = context.CreateGPUTensor(K->DataType(), k_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
// parameters.head_size, K, nullptr, parameters.hidden_size, &Kn));

// TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.kv_sequence_length, parameters.v_head_size});
// TensorShape v_new_shape(v_new_dims);
// Tensor Vn = context.CreateGPUTensor(V->DataType(), v_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
// parameters.v_head_size, V, nullptr, 2 * parameters.hidden_size, &Vn));

// return ApplyAttention(&Qn, &Kn, &Vn, attention_bias, past_key, past_value, output, present_key,
// present_value, parameters, context, true);

constexpr int subgroup_size = 8;
constexpr int tile_size = 8;
bool has_attention_bias = attention_bias != nullptr;
FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads};
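Per the commit title, this path is only known to work when the tile spans exactly one subgroup. A guard like the following (not in the original change) would document that invariant at compile time:

  // Assumption made explicit: FA works on Intel with TILE_SIZE == SUBGROUP_SIZE.
  static_assert(subgroup_size == tile_size,
                "FlashAttention kernel currently assumes one subgroup spans one tile");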
@@ -817,12 +853,12 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_value = context.Output(2, present_shape);

if (parameters.batch_size == 1 &&
bias == nullptr &&
past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 &&
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
bias == nullptr &&
past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 &&
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
}

TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
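The FlashAttention gating condition in ComputeInternal above is easier to audit when factored into a predicate; a sketch (hypothetical helper, not part of this commit):

  // Mirrors the if-condition above: batch 1, no bias, non-empty past/present KV
  // buffers, and a head size divisible by 4, as the FlashAttention path requires.
  bool CanApplyFlashAttention(const AttentionParameters& parameters, const Tensor* bias,
                              const Tensor* past_key, const Tensor* past_value,
                              const Tensor* present_key, const Tensor* present_value) {
    return parameters.batch_size == 1 &&
           bias == nullptr &&
           past_key != nullptr && past_key->SizeInBytes() > 0 &&
           past_value != nullptr && past_value->SizeInBytes() > 0 &&
           present_key != nullptr && present_key->SizeInBytes() > 0 &&
           present_value != nullptr && present_value->SizeInBytes() > 0 &&
           parameters.head_size % 4 == 0;
  }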
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
@@ -101,8 +101,8 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, int components)
: Program{kernel_name}, components_(components) {
CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past)
: Program{kernel_name}, components_(components), has_past_(has_past) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -113,6 +113,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {

private:
int components_;
bool has_past_;
};

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
