Removed is_first_prompt from the uniforms; the flag is now used in a condition while generating the shader code, and was added to the cache hint.
satyajandhyala committed Nov 20, 2024
1 parent a48d782 commit 4334b39
Showing 2 changed files with 31 additions and 37 deletions.
44 changes: 19 additions & 25 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -68,13 +68,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};

-void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
+void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
if (seqlen_k != nullptr) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = 0;\n";
ss << "if (uniforms.is_first_prompt != 0) {\n";
ss << " past_sequence_length = total_sequence_length - sequence_length;\n";
ss << "}\n";
ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
} else {
ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
}
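With the flag available at shader-generation time, the branch folds out of the emitted WGSL entirely: the first-prompt variant hard-codes past_sequence_length = 0, and every later pass hard-codes the subtraction. A standalone sketch of the arithmetic on the seqlen_k path, with invented numbers:

#include <cstdint>
#include <cstdio>

int main() {
  // Invented values; names mirror the WGSL emitted above.
  const uint32_t seqlen_k_batch = 99;                         // u32(seqlen_k[batch_idx])
  const uint32_t total_sequence_length = seqlen_k_batch + 1;  // 100 tokens in the KV cache
  const uint32_t sequence_length = 1;                         // new tokens this pass (a decode step)
  const uint32_t past_first_prompt = 0;                       // first prompt: nothing cached yet
  const uint32_t past_later = total_sequence_length - sequence_length;  // 99 cached tokens
  std::printf("first prompt past=%u, later passes past=%u\n", past_first_prompt, past_later);
  return 0;
}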
@@ -108,7 +105,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
-InitVarStub(oss, seqlen_k_);
+InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str();
if (n_reps_ > 1) {
shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
@@ -122,7 +119,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n";
shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
}
} else {
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
@@ -132,7 +129,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n";
shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
}
}
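In both branches the presentKeyOffset stride changes from uniforms.N (the current total sequence length) to uniforms.present_sequence_length, which on the seqlen_k path is the pre-allocated KV-cache capacity (seqlen_present_kv_cache_; see the uniform below). Per-head offsets into a shared past/present buffer must stride by the buffer's capacity, not by how much of it is currently filled. A sketch with invented sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Invented sizes for illustration.
  const uint32_t K = 32;                          // vectorized head size
  const uint32_t N = 80;                          // current total sequence length
  const uint32_t present_sequence_length = 1024;  // allocated KV-cache capacity
  const uint32_t head_idx = 3;
  const uint32_t old_offset = head_idx * N * K;                        // shifts as N grows
  const uint32_t new_offset = head_idx * present_sequence_length * K;  // stable per head
  std::printf("old=%u new=%u\n", old_offset, new_offset);
  return 0;
}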

@@ -203,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-components, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -225,7 +222,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.SetWorkgroupSize(tile_size, tile_size)
-.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr)
+.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -235,8 +232,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
-{static_cast<uint32_t>(parameters.n_reps)},
-{static_cast<uint32_t>(parameters.is_first_prompt_)}})
+{static_cast<uint32_t>(parameters.n_reps)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
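Since the generated WGSL now differs by components and is_first_prompt, both join the cache hint: every input that changes the generated source has to be part of the pipeline-cache key, or a shader compiled for the other variant could be reused. A minimal sketch of the idea (the real CacheHint lives in the WebGPU program framework):

#include <sstream>
#include <string>

// Sketch only: fold every codegen-affecting input into one lookup key.
template <typename... Parts>
std::string MakeCacheKey(const Parts&... parts) {
  std::ostringstream key;
  ((key << parts << '|'), ...);  // C++17 fold expression
  return key.str();
}

// e.g. MakeCacheKey(tile_size, past_present_share_buffer, feed_past_key,
//                   has_present_key, has_attention_bias, seqlen_k != nullptr,
//                   components, is_first_prompt);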
@@ -255,7 +251,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.sequence_length;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
-InitVarStub(oss, seqlen_k_);
+InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
@@ -309,21 +305,20 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-.CacheHint(std::to_string(work_group_size))
+.CacheHint(work_group_size, is_first_prompt)
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(sequence_length)},
{static_cast<uint32_t>(total_sequence_length_comp)},
-{static_cast<uint32_t>(elementsPerThread)},
-{static_cast<uint32_t>(is_first_prompt)}});
+{static_cast<uint32_t>(elementsPerThread)}});

return context.RunProgram(program);
}
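The net call-shape change for this program, sketched side by side (arguments abbreviated):

// Before: the flag was uploaded on every dispatch as a uniform.
//   InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
//   program.CacheHint(std::to_string(work_group_size))
//          .AddUniformVariables({..., {static_cast<uint32_t>(is_first_prompt)}});
//
// After: the flag specializes the WGSL once, so it moves into the cache key.
//   InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components,
//                                 is_first_prompt, seqlen_k};
//   program.CacheHint(work_group_size, is_first_prompt);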
@@ -352,7 +347,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
-InitVarStub(oss, seqlen_k_);
+InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str();
if (n_reps_ > 1) {
shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n"
@@ -366,7 +361,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n";
shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n";
}
} else {
shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n";
@@ -377,7 +372,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.K + n;\n";
shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n";
}
}

@@ -407,7 +402,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
} else {
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
}
shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.present_sequence_length] = tileK[idx];\n"
<< " }\n";
}

@@ -443,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const bool has_present_value = output_count > 1 && past_value != nullptr;
const int tile_size = 12;

VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
@@ -460,7 +455,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
-.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr)
+.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(total_sequence_length)},
@@ -471,8 +466,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
-{static_cast<uint32_t>(parameters.n_reps)},
-{static_cast<uint32_t>(parameters.is_first_prompt_)}})
+{static_cast<uint32_t>(parameters.n_reps)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
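All three programs (AttentionProbs, InPlaceSoftmax, VxAttentionScore) go through the shared InitVarStub above. A standalone sketch, outside the ORT build, that prints the two WGSL variants the helper now emits on the seqlen_k path:

#include <iostream>
#include <sstream>

// Codegen logic copied from the diff, minus the Tensor dependency.
void InitVarStubSketch(std::ostringstream& ss, bool has_seqlen_k, bool is_first_prompt) {
  if (has_seqlen_k) {
    ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
    ss << "var past_sequence_length: u32 = "
       << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
  } else {
    ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
  }
}

int main() {
  for (bool first_prompt : {true, false}) {
    std::ostringstream ss;
    InitVarStubSketch(ss, /*has_seqlen_k=*/true, first_prompt);
    std::cout << "// is_first_prompt == " << std::boolalpha << first_prompt << ":\n"
              << ss.str() << '\n';
  }
  return 0;
}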
24 changes: 12 additions & 12 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
+bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -49,8 +49,7 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"n_reps", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

@@ -63,12 +62,13 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
+bool is_first_prompt_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
-InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
-: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
+InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
+: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -78,19 +78,19 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});

private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
+bool is_first_prompt_;
};

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
-VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
+VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -104,8 +104,7 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"n_reps", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

@@ -116,6 +115,7 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
+bool is_first_prompt_;
};

} // namespace webgpu
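
The header edits repeat one pattern across all three classes: the flag becomes a constructor parameter and a trailing data member, and leaves the uniform table. A distilled sketch of the pattern (class and names invented for illustration):

#include <sstream>
#include <string>

// State that specializes the *generated* shader is a member (and belongs in
// the cache key); state that varies per dispatch stays a uniform.
class ExampleProgram {
 public:
  explicit ExampleProgram(bool is_first_prompt) : is_first_prompt_(is_first_prompt) {}

  std::string GenerateShaderCode() const {
    std::ostringstream ss;
    ss << "var past_sequence_length: u32 = "
       << (is_first_prompt_ ? "0" : "total_sequence_length - sequence_length") << ";\n";
    return ss.str();
  }

 private:
  bool is_first_prompt_;  // declared and initialized last, matching the diff's ordering
};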
