Use total_seqlen_tensor input only to determine is_first_prompt.

SatyaJandhyalaAtMS committed Nov 4, 2024
1 parent ce7a583 commit 514217f
Showing 4 changed files with 57 additions and 66 deletions.
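In effect, the attention shaders stop reading total_seqlen_tensor on the GPU: the host determines is_first_prompt once and passes it into each program as a uniform and cache-hint key, and present_sequence_length likewise becomes a uniform. As a minimal host-side sketch of that derivation — the helper name, include path, and call site are hypothetical and not shown in this diff; the logic mirrors the WGSL deleted below:

    #include "core/framework/tensor.h"  // assumed ORT include; adjust to the actual build

    // Hypothetical helper, not part of this commit: evaluates on the CPU what the
    // removed WGSL stub computed per shader invocation. Assumes total_seqlen_tensor
    // is a scalar int32 tensor (as GroupQueryAttention's total_sequence_length input is).
    bool IsFirstPrompt(const onnxruntime::Tensor& total_seqlen_tensor, int sequence_length) {
      const int total_sequence_length = static_cast<int>(total_seqlen_tensor.Data<int32_t>()[0]);
      const bool is_subsequent_prompt = sequence_length > 1 && sequence_length != total_sequence_length;
      return !is_subsequent_prompt && sequence_length == total_sequence_length;
    }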
86 changes: 36 additions & 50 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -68,22 +68,15 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};

-void InitVarStub(const Tensor* seqlens_k, const Tensor* total_seqlen_tensor, bool init_past_sequence_length, std::ostringstream& ss) {
-if (seqlens_k != nullptr && total_seqlen_tensor != nullptr) {
-ss << "let total_sequence_length_input = u32(total_seqlen_tensor[0]);\n";
-ss << "let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);\n";
-ss << "let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;\n";
-ss << "let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;\n";
+void InitVarStub(std::ostringstream& ss, const Tensor* seqlens_k) {
+if (seqlens_k != nullptr) {
ss << "total_sequence_length = u32(seqlens_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = 0;\n";
ss << "if (is_first_prompt == false) {\n";
ss << "if (uniforms.is_first_prompt != 0) {\n";
ss << " past_sequence_length = total_sequence_length - sequence_length;\n";
ss << "}\n";
} else {
-if (init_past_sequence_length) {
-ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
-}
-ss << "let present_sequence_length = total_sequence_length;\n";
+ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
}
}
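For reference, when seqlens_k is present the stub now emits WGSL along these lines (reassembled from the ss << strings above; total_sequence_length is declared earlier by the calling shader):

    // WGSL emitted by the new InitVarStub on the seqlens_k path, reassembled verbatim:
    total_sequence_length = u32(seqlens_k[batch_idx]) + 1;
    var past_sequence_length: u32 = 0;
    if (uniforms.is_first_prompt != 0) {
      past_sequence_length = total_sequence_length - sequence_length;
    }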

@@ -99,9 +92,6 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (seqlen_k_ != nullptr) {
shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
}
-if (total_seqlen_tensor_ != nullptr) {
-shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (has_present_key_) {
shader.AddOutput("present_key", ShaderUsage::UseUniform);
@@ -118,7 +108,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
-InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
if (n_reps_ > 1) {
shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
@@ -163,7 +153,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

if (has_present_key_) {
shader.MainFunctionBody() << " if (n + local_id.y < present_sequence_length) {\n"
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n"
<< " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
<< " }\n";
}
@@ -194,7 +184,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length,
-const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+const Tensor* seqlen_k) {
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;

@@ -205,7 +195,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-components, parameters.n_reps, seqlen_k, total_seqlen_tensor};
+components, parameters.n_reps, seqlen_k};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -214,9 +204,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
-if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
-program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
-{total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+if (seqlen_k != nullptr) {
+program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
if (has_present_key) {
@@ -228,7 +217,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.SetWorkgroupSize(tile_size, tile_size)
-.CacheHint(std::to_string(tile_size))
+.CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -237,7 +226,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
{static_cast<float>(alpha)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
-{static_cast<uint32_t>(parameters.n_reps)}})
+{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
+{static_cast<uint32_t>(parameters.n_reps)},
+{static_cast<uint32_t>(parameters.is_first_prompt_)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
@@ -247,9 +238,6 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (seqlen_k_) {
shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
}
-if (total_seqlen_tensor_) {
-shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-}
shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
@@ -259,7 +247,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.sequence_length;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
-InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
@@ -304,7 +292,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
-const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+const Tensor* seqlen_k, bool is_first_prompt) {
const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
int work_group_size = 64;
const int total_sequence_length_comp = total_sequence_length / components;
@@ -313,20 +301,21 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k, total_seqlen_tensor};
if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
{total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
+.CacheHint(std::to_string(work_group_size), is_first_prompt)
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(sequence_length)},
{static_cast<uint32_t>(total_sequence_length_comp)},
-{static_cast<uint32_t>(elementsPerThread)}});
+{static_cast<uint32_t>(elementsPerThread)},
+{static_cast<uint32_t>(is_first_prompt)}});

return context.RunProgram(program);
}
@@ -340,9 +329,6 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (seqlen_k_) {
shader.AddInput("seqlens_k", ShaderUsage::UseUniform);
}
-if (total_seqlen_tensor_) {
-shader.AddInput("total_seqlen_tensor", ShaderUsage::UseUniform);
-}
shader.AddOutput("output", ShaderUsage::UseUniform);
if (has_present_value_) {
shader.AddOutput("present_value", ShaderUsage::UseUniform);
@@ -358,7 +344,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
-InitVarStub(seqlen_k_, total_seqlen_tensor_, true, oss);
+InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
if (n_reps_ > 1) {
shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
@@ -404,7 +390,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

if (has_present_value_) {
shader.MainFunctionBody() << " if (w + local_id.y < present_sequence_length) {\n"
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n"
<< " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
<< " }\n";
}
@@ -436,21 +422,19 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
WebgpuAttentionParameters& parameters,
int past_sequence_length,
int total_sequence_length,
-const Tensor* seqlen_k,
-const Tensor* total_seqlen_tensor) {
+const Tensor* seqlen_k) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0;
const bool has_present_value = output_count > 1 && past_value != nullptr;
const int tile_size = 12;

VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, total_seqlen_tensor};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
}
-if (seqlen_k != nullptr && total_seqlen_tensor != nullptr) {
-program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
-{total_seqlen_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
+if (seqlen_k != nullptr) {
+program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
if (has_present_value) {
@@ -460,6 +444,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
+.CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -469,16 +454,17 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
{static_cast<uint32_t>(parameters.v_hidden_size_)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
-{static_cast<uint32_t>(parameters.n_reps)}})
+{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
+{static_cast<uint32_t>(parameters.n_reps)},
+{static_cast<uint32_t>(parameters.is_first_prompt_)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
-;

return context.RunProgram(program);
}

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, const Tensor* total_seqlen_tensor) {
+WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
@@ -488,13 +474,13 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
const TensorShape probs_shape(probs_dims);
Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
-parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor));
+parameters, past_sequence_length, total_sequence_length, seqlen_k));

ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
-parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, total_seqlen_tensor));
+parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));

ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
-parameters, past_sequence_length, total_sequence_length, seqlen_k, total_seqlen_tensor));
+parameters, past_sequence_length, total_sequence_length, seqlen_k));

return Status::OK();
}