FA implementation - doesn't work but should be all the code
sushraja-msft committed Nov 22, 2024
1 parent 00461d1 commit e8bb833
Showing 3 changed files with 417 additions and 1 deletion.
367 changes: 367 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -403,6 +403,365 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
return Status::OK();
}

Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Expectations are
// qkv have same number of heads and hidden dimension (head size).
// qkv are in BSNH format.
// B - batch size but shader only supports batch_size 1.
// S - current sequence length but shader supports only S = 1.
// N - number of heads.
// H - head size or hidden dimension for each qkv head.
// KV cache is stored as BN(total_sequence_length)H
// Attention bias is in BN(total_sequence_length)
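// Worked example (assumed config: num_heads = 32, head_size = 96, components = 4, so
// vectorized_head_size = 24): the workgroup at (kIdx, 0, headIdx) copies one K row and one V row,
// 24 vec4 values each, into present_key/present_value starting at
// headIdx * total_sequence_length * 24 + kIdx * 24.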
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (has_past_) {
shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
}
shader.AddOutput("present_key", ShaderUsage::UseUniform);
shader.AddOutput("present_value", ShaderUsage::UseUniform);

shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
<< "let kIdx = workgroup_id.x;\n"
<< "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n";
if (has_past_) {
shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n"
<< " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
<< " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n"
<< " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n"
<< " }\n"
<< "}\n"
<< "else if (kIdx >= uniforms.past_sequence_length) {\n";
} else {
shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n";
}
shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n"
<< " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n"
<< " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n"
<< " // Assumes kv have BNSH layout.\n"
<< " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n"
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
<< " present_key[presentKeyOffset+w] = key[nOffset+w];\n"
<< " present_value[presentKeyOffset+w] = value[nOffset+w];\n"
<< " }\n"
<< "}\n";

return Status::OK();
}

Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters,
const Tensor* K, const Tensor* past_key, Tensor* present_key,
const Tensor* V, const Tensor* past_value, Tensor* present_value,
int past_sequence_length, int total_sequence_length) {

const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1);
bool has_past = (past_sequence_length != 0);
CopyKVCacheProgram program{"CopyKVCache", components, has_past};
if (has_past) {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_key, ProgramTensorMetadataDependency::TypeAndRank, components},
{past_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
} else {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
}

program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
{present_value, ProgramTensorMetadataDependency::Rank, components}});

program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads)
.SetWorkgroupSize(1)
.CacheHint(std::to_string(components) + std::to_string(has_past))
.AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length)},
{static_cast<uint32_t>(parameters.head_size / components)}});

return context.RunProgram(program);
}

Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Expectations are
// qkv have same number of heads and hidden dimension (head size).
// qkv are in BSNH format.
// B - batch size but shader only supports batch_size 1.
// S - current sequence length but shader supports only S = 1.
// N - number of heads.
// H - head size or hidden dimension for each qkv head.
// KV cache is stored as BN(total_sequence_length)H
// Attention bias is in BN(new_sequence_length)(total_sequence_length)
//
// The expectation is that present_key and present_value already contain the past key and values,
// since we are out of storage buffers a shader can bind and past/present cannot both be passed.
// The hidden size of each q head should be a multiple of 4 because the shader uses vectorized loads.
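// Concrete example (assumed config: 32 heads, head size 96): q comes in packed as
// [1, new_sequence_length, 32 * 96]; a present K/V row starts at element offset
// (head_idx * present_sequence_length + k_idx) * head_size; attention_bias is indexed as
// head_idx * new_sequence_length * present_sequence_length + q_idx * present_sequence_length + k_idx.
// The load helpers below apply the same arithmetic in vec4-sized units.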
constexpr int vectorization_size = 4;
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("present_key", ShaderUsage::UseUniform);
shader.AddInput("present_value", ShaderUsage::UseUniform);
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform);

// SUBGROUP_SIZE has to be the same as sg_size. For Intel this will be 8.
// TILE_SIZE is the number of groups sharing the k_tile.
// TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when
// TILE_SIZE == SUBGROUP_SIZE. This is a separate constant from SUBGROUP_SIZE
// because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 per the WebGPU
// workgroup size limits. For Intel this TILE_SIZE will be 8.
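// For example, with sg_size = 8 and TILE_SIZE = 8 (the assumed Intel config above), the workgroup
// is 8 x 8 = 64 invocations: 8 waves, each owning one q row and one k_tile/v_tile slot.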
shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n"
<< "const TILE_SIZE: u32 = " << tile_size_ << ";\n"
<< "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n"
<< "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n"
<< "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n"
<< "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n"
<< "const MIN_VALUE : q_element_t = -6504.0h;\n";

// Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU Tiger Lake
// GPU, beyond which workgroups get descheduled to make room for the shared memory.
shader.AdditionalImplementation() << "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> qk_tile : array<array<q_element_t, TILE_SIZE>, TILE_SIZE>; // 8 * 2 * 8 = 128\n"
<< "var<workgroup> max_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n"
<< "var<workgroup> denom_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n"
<< "var<workgroup> o_ratio : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n";

shader.AdditionalImplementation() << R"HELPER_FN(
fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
{
if (q_idx_global >= uniforms.new_sequence_length) {
return;
}
// Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA
// This is the layout if TransferBSDToBNSH has not been run.
let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx;
// Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run.
// let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size)
{
var value = q[idx+offset];
q_tile[slot][idx] = value;
}
}
fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
{
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
{
var value = present_key[idx+offset];
k_tile[slot][idx] = value;
}
}
fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
{
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
{
v_tile[slot][idx] = present_value[idx+offset];
}
}
fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
{
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) {
qk_tile[q_row][k_col] = 0.0;
return;
}
let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global;
qk_tile[q_row][k_col] = attention_bias[offset];
}
fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
{
if (o_idx_global >= uniforms.new_sequence_length) {
return;
}
// Stored as float16[batch_size,sequence_length,3072]
let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size)
{
let value = o_tile[slot][idx];
output[offset+idx] = value;
}
}
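// computeDotProduct: each lane multiplies one vec4 chunk of the q and k rows, subgroupAdd
// reduces the partial products across the subgroup, and lane 0 scales the final sum by
// uniforms.alpha and accumulates it into qk_tile[q_idx][k_idx].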
fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32)
{
var sum:vec4<q_element_t> = q_value_t(0, 0, 0, 0);
for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size)
{
var result = q_value_t(0);
let sg_idx = idx+sg_id;
// If QKV_HEAD_VECTORIZED_SIZE is divisible by the subgroup size, this check is not
// required. Hopefully the compiler sees the first half of this condition and
// removes the branch.
if (QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 || sg_idx < QKV_HEAD_VECTORIZED_SIZE)
{
result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx];
}
sum += subgroupAdd(result);
}
if (sg_id == 0)
{
let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w;
let sqrt_dk = q_element_t(uniforms.alpha);
let value = single_sum * sqrt_dk;
qk_tile[q_idx][k_idx] += value;
}
}
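// computeSoftMax implements the online (streaming) softmax recurrence per q row:
//   m_new = max(m_old, max_j x_j)
//   d_new = d_old * exp(m_old - m_new) + sum_j exp(x_j - m_new)
// qk_tile is overwritten with exp(x_j - m_new) / d_new, and o_ratio = d_old * exp(m_old - m_new) / d_new
// is the factor computeO uses to rescale the output accumulated from earlier tiles.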
fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool)
{
var x = MIN_VALUE;
if (enabled){
x = qk_tile[q_idx][sg_id];
}
var max_value = subgroupMax(x);
max_value = max(max_tile[q_idx], max_value);
let sub = x - max_value;
var value:q_element_t = 0;
if (enabled) {
value = exp(sub);
}
let sum = subgroupAdd(value);
// Compute the left-hand term of the denominator update, then compute the new denominator d.
let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value);
var d = dleft + sum;
if (d == 0)
{
d = 0.0000001h;
}
qk_tile[q_idx][sg_id] = value / d;
if (sg_id == 0)
{
max_tile[q_idx] = max_value;
denom_tile[q_idx] = d;
o_ratio[q_idx] = dleft / d;
}
}
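// computeO folds this tile's contribution into the running output:
//   o_new[i] = o_ratio * o_old[i] + sum_k p[k] * v[k][i]
// where p[k] = qk_tile[q_idx][k] is already normalized by the updated denominator.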
fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
{
var attn = q_element_t(0);
if (enabled)
{
attn = qk_tile[q_idx][sg_id];
}
for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++)
{
let val = v_tile[sg_id][i];
var intermediate = attn * val;
let sum = subgroupAdd(intermediate);
if (sg_id == 0)
{
let o_ratio = o_ratio[q_idx];
let old_o = o_tile[q_idx][i];
let new_o = ( o_ratio * old_o) + sum;
o_tile[q_idx][i] = new_o;
}
}
}
)HELPER_FN";

// Shader is designed to be dispatched as Dispatch(num_heads, new_seq_length / TILE_SIZE, 1)
shader.MainFunctionBody() << R"MAIN_FN(
let head_idx = workgroup_id.x;
let wave_x = (local_id.x / 4);
let wave_y = (local_id.y / 4);
let wave_id = wave_x + wave_y * 4;
let q_idx_start = workgroup_id.y * TILE_SIZE;
let q_idx_global = q_idx_start + wave_id;
// Each wave (wave_id) gets a subgroup of lanes and is responsible for one query row.
loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
if (sg_id == 0)
{
max_tile[wave_id] = MIN_VALUE;
}
for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE)
{
let k_idx_global = k_start+wave_id;
let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length;
if (k_idx_global_using_wave_valid) {
// Leveraging the subgroup lanes for parallelism, load the K/V values at
// k_start+wave_id into slot wave_id.
loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size);
loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
// Next, for every q row (wave_id), populate the attention bias for key position
// (k_start+sg_id). loadAttentionBias handles range checking of q_idx_global
// and (k_start+sg_id).
loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
}
workgroupBarrier();
if (k_idx_global_using_wave_valid)
{
for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++)
{
// Leveraging the subgroups for parallelism, compute the QK dot products.
// Because in the new_seq == 1 case there is a single query against a full context of K,
// we iterate over q and use the waves for K so that this step can use all the waves
// in the workgroup.
computeDotProduct(q_idx, wave_id, sg_id, sg_size);
}
}
let k_idx_global_using_lane_valid:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
computeSoftMax(wave_id, sg_id, k_idx_global_using_lane_valid);
computeO(wave_id, sg_id, k_idx_global_using_lane_valid);
}
workgroupBarrier();
writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
)MAIN_FN";

return Status::OK();
}

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length));

constexpr int subgroup_size = 16;
constexpr int tile_size = 16;
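// These constants assume a 16-wide subgroup; the 16 x 16 workgroup below is 256 invocations,
// matching the SUBGROUP_SIZE * TILE_SIZE <= 256 WebGPU limit noted in the shader.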
bool has_attention_bias = attention_bias != nullptr;
FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
if (has_attention_bias) {
program.AddInputs({{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size))
: parameters.scale;
std::string cache_hint = std::to_string(has_attention_bias) +
std::to_string(subgroup_size) +
std::to_string(tile_size) +
std::to_string(parameters.head_size) +
std::to_string(parameters.num_heads);
program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1)
.SetWorkgroupSize(subgroup_size, subgroup_size)
.CacheHint(cache_hint)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
{static_cast<uint32_t>(parameters.total_sequence_length)},
{alpha}});

return context.RunProgram(program);
}

MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
: WebGpuKernel(info) {
int64_t num_heads = 0;
@@ -457,6 +816,14 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_key = context.Output(1, present_shape);
Tensor* present_value = context.Output(2, present_shape);

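// Flash attention path: only taken for batch size 1, no bias input, head_size a multiple of 4
// (vectorized loads), and present key/value buffers available to hold the full KV cache
// (past key/values are copied into them by CopyKVCache).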
if (parameters.batch_size == 1 &&
bias == nullptr &&
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
}

TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
parameters.sequence_length, parameters.head_size});
TensorShape q_new_shape(q_new_dims);