Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu EP] Binary operators #22112

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ ONNX_OPERATOR_KERNEL_EX(
FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

std::string add_bias = "";
if (Inputs().size() > 1) {
const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride);
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n"
" x += input_value_t(" +
bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
Expand Down
311 changes: 311 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Program object for a WebGPU binary element-wise operator kernel.
//
// The constructor only records its arguments; the shader source itself is produced by GenerateShaderCode(),
// which is declared here and defined out of line.
// NOTE(review): how `expression` and the boolean flags influence the generated shader is decided in the .cc
// file, which is not visible from this header - confirm there before relying on a particular meaning.
class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram> {
 public:
  // `kernel_name` is handed to the Program base class; every other argument is copied into the member of the
  // same name (with a trailing underscore).
  BinaryElementwiseProgram(const std::string& kernel_name,
                           const std::string& expression,
                           const bool is_broadcast,
                           const bool is_lhs_scalar,
                           const bool is_rhs_scalar,
                           const bool vectorize)
      : Program{kernel_name},
        expression_{expression},
        is_broadcast_{is_broadcast},
        is_lhs_scalar_{is_lhs_scalar},
        is_rhs_scalar_{is_rhs_scalar},
        vectorize_{vectorize} {
  }

  // Emits the shader code for this program through the given helper.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // The program declares a single uniform: "vec_size", a 32-bit unsigned integer.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

 private:
  // Copies of the constructor arguments. Declaration order matches the initializer list above.
  std::string expression_;
  bool is_broadcast_;
  bool is_lhs_scalar_;
  bool is_rhs_scalar_;
  bool vectorize_;
};

// WebGPU kernel class for binary element-wise operators.
//
// It keeps its own copies of a kernel name and an expression string and declares ComputeInternal() as `final`,
// so classes deriving from it cannot override the compute entry point.
// NOTE(review): presumably concrete operators derive from this class and supply the expression - confirm in
// binary_elementwise_ops.cc, which is not visible here.
class BinaryElementwise : public WebGpuKernel {
public:
// `info` is forwarded to the WebGpuKernel base; `kernel_name` and `expression` are copied into members.
BinaryElementwise(const OpKernelInfo& info,
const std::string& kernel_name,
const std::string& expression) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression} {}

protected:
// Compute entry point; declared here, defined out of line.
Status ComputeInternal(ComputeContext& context) const final;

private:
// Copies of the constructor arguments.
std::string kernel_name_;
std::string expression_;

// NOTE(review): the lines between here and the closing brace are a CI (cpplint) annotation embedded in the
// scraped page, not source code. The finding itself - a missing #include <string> - applies to this header.
Check warning on line 52 in onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h:52: Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace webgpu
} // namespace onnxruntime
18 changes: 9 additions & 9 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
namespace onnxruntime {
namespace webgpu {
Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_);
const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform);
shader.AppendImplementation(additional_impl_);
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
" let a = ", input.GetByOffset("global_idx"), ";\n ",
Expand Down Expand Up @@ -98,7 +98,7 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)")
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes())

Expand All @@ -113,7 +113,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes())
class HardSigmoid final : public UnaryElementwise {
public:
HardSigmoid(const OpKernelInfo& info)
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} {
// attr[0] is alpha, attr[1] is beta
info.GetAttrOrDefault("alpha", attr, 0.2f);
info.GetAttrOrDefault("beta", attr + 1, 0.5f);
Expand Down Expand Up @@ -154,7 +154,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)")
WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes())

Expand All @@ -180,7 +180,7 @@ class Clip final : public UnaryElementwise {
: UnaryElementwise{info,
"Clip",
std::is_same_v<T, MLFloat16> ? ClipF16Impl : ClipImpl,
"", ShaderVariable::UseElementTypeAlias} {}
"", ShaderUsage::UseElementTypeAlias} {}

Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override {
const auto* clip_min_tensor = context.Input<Tensor>(1);
Expand Down Expand Up @@ -240,7 +240,7 @@ class LinearUnit : public UnaryElementwise {
const std::string& expression,
const std::string& additional_impl,
float default_alpha)
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} {
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
}

Expand Down Expand Up @@ -269,14 +269,14 @@ class Gelu : public UnaryElementwise {
"Gelu",
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
ShaderVariable::UseValueTypeAlias} {
ShaderUsage::UseValueTypeAlias} {
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
}
};

WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())

WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias)
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes())
WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes())
Expand Down
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class UnaryElementwiseProgram final : public Program<UnaryElementwiseProgram> {
public:
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage)
UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage)

Check warning on line 15 in onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h:15: Lines should be <= 120 characters long [whitespace/line_length] [2]
: Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} {
}

Expand All @@ -26,7 +26,7 @@
private:
std::string_view expression_;
std::string_view additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch
Expand All @@ -38,11 +38,11 @@
const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl = "",
ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}
ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info},
kernel_name_{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
additional_usage_{usage} {}

protected:
std::string cache_hint;
Expand All @@ -57,7 +57,7 @@
std::string kernel_name_;
std::string expression_;
std::string additional_impl_;
ShaderVariable::Usage additional_usage_;
ShaderUsage additional_usage_;
};

constexpr const char ErfImpl[] = R"(
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,16 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list<ProgramOutput> output
return *this;
}

// Appends a copy of `shape` to the list of indices shapes held by this program.
// Returns *this so that calls can be chained.
ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) {
  indices_.push_back(shape);
  return *this;
}

// Appends `shape` to the list of indices shapes held by this program, consuming the caller's rvalue.
// Returns *this so that calls can be chained.
ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) {
  // A named rvalue-reference parameter is itself an lvalue expression, so passing `shape` as-is would select
  // TensorShape's copy constructor and make this overload behave exactly like the const& one. Cast it back to
  // an rvalue (this is precisely what std::move does) so the new element is move-constructed instead.
  indices_.emplace_back(static_cast<TensorShape&&>(shape));
  return *this;
}

// Convenience overload: sets the dispatch group size along X only by forwarding to the three-argument
// overload with Y and Z both set to 1. Returns whatever that overload returns.
ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) {
return SetDispatchGroupSize(x, 1, 1);
}
Expand Down Expand Up @@ -309,4 +319,4 @@ ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list<ProgramO
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
18 changes: 6 additions & 12 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,6 @@ inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependen

constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;

// Represents the scope of a variable in a shader program.
//
// This is not a full list of all possible variable scopes in shader programs;
// it only includes the scopes that are used by the WebGPU EP.
//
// The enumerators are numbered consecutively from 0, so `Count` (implicitly 3) equals the number of scopes.
enum class ProgramVariableScope {
Input = 0, // storage buffer variable with access mode "read"
Output = 1, // storage buffer variable with access mode "read_write"
Local = 2, // local variable

Count // number of scopes above; must always remain the last enumerator
};

// data type of variable
//
// this is not a full list of all possible data types in shader programs.
Expand Down Expand Up @@ -265,6 +253,10 @@ class ProgramBase {
ProgramBase& AddOutput(ProgramOutput&& output);
// add multiple program outputs
ProgramBase& AddOutputs(std::initializer_list<ProgramOutput> outputs);
// add a program variable for indices; the given shape is copied into the program
ProgramBase& AddIndices(const TensorShape& shape);
// add a program variable for indices; overload for an rvalue (temporary or explicitly moved) shape
ProgramBase& AddIndices(TensorShape&& shape);

// set the size of dispatch groups. Y and Z are 1 if not specified.
ProgramBase& SetDispatchGroupSize(uint32_t x);
Expand Down Expand Up @@ -330,6 +322,7 @@ class ProgramBase {
// Read-only accessors. Each one returns the corresponding private member unchanged: the cache hint and the
// input/output/indices lists by const reference, the dispatch group sizes by value.
inline const std::string& CacheHint() const { return cache_hint_; }
inline const std::vector<ProgramInput>& Inputs() const { return inputs_; }
inline const std::vector<ProgramOutput>& Outputs() const { return outputs_; }
// shapes registered through AddIndices(), in insertion order
inline const std::vector<TensorShape>& Indices() const { return indices_; }
inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; }
inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; }
inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; }
Expand All @@ -351,6 +344,7 @@ class ProgramBase {
std::string cache_hint_;
std::vector<ProgramInput> inputs_;
std::vector<ProgramOutput> outputs_;
std::vector<TensorShape> indices_;

uint32_t dispatch_group_size_x_;
uint32_t dispatch_group_size_y_;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/program_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ Status ProgramManager::Build(const ProgramBase& program,

ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper));

ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices());

// code is a large std::string that contains the final shader code
std::string code;
Expand Down
Loading
Loading