Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 16, 2024
1 parent c5cf2ab commit 0e556ef
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 208 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ ONNX_OPERATOR_KERNEL_EX(
FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

std::string add_bias = "";
if (Inputs().size() > 1) {
const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride);
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n"
" x += input_value_t(" +
bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
Expand Down
18 changes: 9 additions & 9 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
namespace onnxruntime {
namespace webgpu {
Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform);
shader.AppendImplementation(additional_impl_);
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
" let a = ", input.GetByOffset("global_idx"), ";\n ",
Expand Down Expand Up @@ -98,7 +98,7 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)")
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes())

Expand All @@ -113,7 +113,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes())
class HardSigmoid final : public UnaryElementwise {
public:
HardSigmoid(const OpKernelInfo& info)
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} {
// attr[0] is alpha, attr[1] is beta
info.GetAttrOrDefault("alpha", attr, 0.2f);
info.GetAttrOrDefault("beta", attr + 1, 0.5f);
Expand Down Expand Up @@ -154,7 +154,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)")
WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes())

Expand All @@ -180,7 +180,7 @@ class Clip final : public UnaryElementwise {
: UnaryElementwise{info,
"Clip",
std::is_same_v<T, MLFloat16> ? ClipF16Impl : ClipImpl,
"", ShaderVariable::UseElementTypeAlias} {}
"", ShaderUsage::UseElementTypeAlias} {}

Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override {
const auto* clip_min_tensor = context.Input<Tensor>(1);
Expand Down Expand Up @@ -240,7 +240,7 @@ class LinearUnit : public UnaryElementwise {
const std::string& expression,
const std::string& additional_impl,
float default_alpha)
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
}

Expand Down Expand Up @@ -269,14 +269,14 @@ class Gelu : public UnaryElementwise {
"Gelu",
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
ShaderVariable::UseValueTypeAlias} {
ShaderUsage::UseValueTypeAlias} {
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
}
};

WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes())
Expand Down
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace webgpu {

class UnaryElementwiseProgram final : public Program<UnaryElementwiseProgram> {
public:
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage)
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage)
: Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} {
}

Expand All @@ -26,7 +26,7 @@ class UnaryElementwiseProgram final : public Program<UnaryElementwiseProgram> {
private:
std::string_view expression_;
std::string_view additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch
Expand All @@ -38,11 +38,11 @@ class UnaryElementwise : public WebGpuKernel {
const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl = "",
ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}
ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}

protected:
std::string cache_hint;
Expand All @@ -57,7 +57,7 @@ class UnaryElementwise : public WebGpuKernel {
std::string kernel_name_;
std::string expression_;
std::string additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

constexpr const char ErfImpl[] = R"(
Expand Down
12 changes: 0 additions & 12 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,6 @@ inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependen

constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;

// represents the scope of a variable in a shader program.
//
// this is not a full list of all possible variable scopes in shader programs.
// it only includes what are used in WebGPU EP.
enum class ProgramVariableScope {
Input = 0, // storage buffer variable with access mode "read"
Output = 1, // storage buffer variable with access mode "read_write"
Local = 2, // local variable

Count // should always be the last element
};

// data type of variable
//
// this is not a full list of all possible data types in shader programs.
Expand Down
Loading

0 comments on commit 0e556ef

Please sign in to comment.