Skip to content

Commit

Permalink
[webgpu EP] Binary operators (#22112)
Browse files Browse the repository at this point in the history
based on:
- #22058

---------

Co-authored-by: Qin Jiajia <[email protected]>
  • Loading branch information
fs-eire and qjia7 authored Sep 17, 2024
1 parent c5cf2ab commit 0bc714f
Show file tree
Hide file tree
Showing 17 changed files with 736 additions and 262 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ ONNX_OPERATOR_KERNEL_EX(
FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

std::string add_bias = "";

Check warning on line 26 in onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc:26: Add #include <string> for string [build/include_what_you_use] [4]
if (Inputs().size() > 1) {
const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride);
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n"
" x += input_value_t(" +
bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
Expand Down
311 changes: 311 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram> {
public:
BinaryElementwiseProgram(const std::string& kernel_name,
const std::string& expression,
const bool is_broadcast,
const bool is_lhs_scalar,
const bool is_rhs_scalar,
const bool vectorize) : Program{kernel_name},
expression_{expression},
is_broadcast_{is_broadcast},
is_lhs_scalar_{is_lhs_scalar},
is_rhs_scalar_{is_rhs_scalar},
vectorize_{vectorize} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

private:
std::string expression_;
bool is_broadcast_;
bool is_lhs_scalar_;
bool is_rhs_scalar_;
bool vectorize_;
};

class BinaryElementwise : public WebGpuKernel {
public:
BinaryElementwise(const OpKernelInfo& info,
const std::string& kernel_name,
const std::string& expression) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression} {}

protected:
Status ComputeInternal(ComputeContext& context) const final;

private:
std::string kernel_name_;
std::string expression_;
};

} // namespace webgpu
} // namespace onnxruntime
18 changes: 9 additions & 9 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
namespace onnxruntime {
namespace webgpu {
Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform);
shader.AppendImplementation(additional_impl_);
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
" let a = ", input.GetByOffset("global_idx"), ";\n ",
Expand Down Expand Up @@ -98,7 +98,7 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)")
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes())

Expand All @@ -113,7 +113,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes())
class HardSigmoid final : public UnaryElementwise {
public:
HardSigmoid(const OpKernelInfo& info)
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} {
// attr[0] is alpha, attr[1] is beta
info.GetAttrOrDefault("alpha", attr, 0.2f);
info.GetAttrOrDefault("beta", attr + 1, 0.5f);
Expand Down Expand Up @@ -154,7 +154,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)")
WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes())

Expand All @@ -180,7 +180,7 @@ class Clip final : public UnaryElementwise {
: UnaryElementwise{info,
"Clip",
std::is_same_v<T, MLFloat16> ? ClipF16Impl : ClipImpl,
"", ShaderVariable::UseElementTypeAlias} {}
"", ShaderUsage::UseElementTypeAlias} {}

Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override {
const auto* clip_min_tensor = context.Input<Tensor>(1);
Expand Down Expand Up @@ -240,7 +240,7 @@ class LinearUnit : public UnaryElementwise {
const std::string& expression,
const std::string& additional_impl,
float default_alpha)
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
}

Expand Down Expand Up @@ -269,14 +269,14 @@ class Gelu : public UnaryElementwise {
"Gelu",
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
ShaderVariable::UseValueTypeAlias} {
ShaderUsage::UseValueTypeAlias} {
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
}
};

WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes())
Expand Down
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace webgpu {

class UnaryElementwiseProgram final : public Program<UnaryElementwiseProgram> {
public:
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage)
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage)
: Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} {
}

Expand All @@ -26,7 +26,7 @@ class UnaryElementwiseProgram final : public Program<UnaryElementwiseProgram> {
private:
std::string_view expression_;
std::string_view additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch
Expand All @@ -38,11 +38,11 @@ class UnaryElementwise : public WebGpuKernel {
const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl = "",
ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}
ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}

protected:
std::string cache_hint;
Expand All @@ -57,7 +57,7 @@ class UnaryElementwise : public WebGpuKernel {
std::string kernel_name_;
std::string expression_;
std::string additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

constexpr const char ErfImpl[] = R"(
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,16 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list<ProgramOutput> output
return *this;
}

ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) {
indices_.emplace_back(shape);
return *this;
}

ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) {
indices_.emplace_back(shape);
return *this;
}

ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) {
return SetDispatchGroupSize(x, 1, 1);
}
Expand Down Expand Up @@ -309,4 +319,4 @@ ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list<ProgramO
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
18 changes: 6 additions & 12 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,6 @@ inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependen

constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;

// represents the scope of a variable in a shader program.
//
// this is not a full list of all possible variable scopes in shader programs.
// it only includes what are used in WebGPU EP.
enum class ProgramVariableScope {
Input = 0, // storage buffer variable with access mode "read"
Output = 1, // storage buffer variable with access mode "read_write"
Local = 2, // local variable

Count // should always be the last element
};

// data type of variable
//
// this is not a full list of all possible data types in shader programs.
Expand Down Expand Up @@ -265,6 +253,10 @@ class ProgramBase {
ProgramBase& AddOutput(ProgramOutput&& output);
// add multiple program outputs
ProgramBase& AddOutputs(std::initializer_list<ProgramOutput> outputs);
// add a program variable for indices
ProgramBase& AddIndices(const TensorShape& shape);
// add a program variable for indices
ProgramBase& AddIndices(TensorShape&& shape);

// set the size of dispatch groups. Y and Z are 1 if not specified.
ProgramBase& SetDispatchGroupSize(uint32_t x);
Expand Down Expand Up @@ -330,6 +322,7 @@ class ProgramBase {
inline const std::string& CacheHint() const { return cache_hint_; }
inline const std::vector<ProgramInput>& Inputs() const { return inputs_; }
inline const std::vector<ProgramOutput>& Outputs() const { return outputs_; }
inline const std::vector<TensorShape>& Indices() const { return indices_; }
inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; }
inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; }
inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; }
Expand All @@ -351,6 +344,7 @@ class ProgramBase {
std::string cache_hint_;
std::vector<ProgramInput> inputs_;
std::vector<ProgramOutput> outputs_;
std::vector<TensorShape> indices_;

uint32_t dispatch_group_size_x_;
uint32_t dispatch_group_size_y_;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/program_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ Status ProgramManager::Build(const ProgramBase& program,

ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper));

ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices());

// code is a large std::string that contains the final shader code
std::string code;
Expand Down
Loading

0 comments on commit 0bc714f

Please sign in to comment.