diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
index 7d8bef1e66f42..50debe26ce456 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
@@ -20,12 +20,12 @@ ONNX_OPERATOR_KERNEL_EX(
     FastGelu);
 
 Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
-  const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
 
   std::string add_bias = "";
   if (Inputs().size() > 1) {
-    const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride);
+    const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
     add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n"
                                        " x += input_value_t(" +
                                            bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
new file mode 100644
index 0000000000000..9d9eff2ccdde5
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
@@ -0,0 +1,311 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/common.h"
+#include "core/providers/webgpu/math/binary_elementwise_ops.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+
+namespace onnxruntime {
+namespace webgpu {
+Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  std::string common;
+  std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n"
+                                          : "let a = " + a.GetByOffset("global_idx") + ";\n";
+  std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n"
+                                          : "let b = " + b.GetByOffset("global_idx") + ";\n";
+  // check whether the element-wise mode can be used.
+  // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used.
+  // In element-wise mode, no indices calculation is needed.
+  if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) {
+    const auto& c_indices = shader.AddIndices("bcast_indices");
+    // check whether the vectorize mode can be used.
+    // If the last dimension of either A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode
+    // can be enabled.
+    // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values.
+    // Use indices helpers to calculate the offset of A and B.
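+    //
+    // Rough sketch of what the three paths below assemble for expression "a + b"
+    // (identifiers follow the AddInput/AddOutput calls above):
+    //   element-wise:  let a = input_a[global_idx];
+    //                  let b = input_b[global_idx];
+    //                  output[global_idx] = a + b;
+    //   vectorize:     one broadcasted offset per operand, then a single vec4
+    //                  load (or a scalar load splatted to a vec4 value);
+    //   broadcast:     four per-element offsets per operand, assembled into
+    //                  vec4 values before evaluating the expression.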
+ if (vectorize_) { + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + common = "let outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + + ";\n" + "let offset_a = " + + a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b = " + + b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; + get_a_data = a.NumComponents() == 4 ? "let a = " + a.GetByOffset("offset_a / 4") + ";\n" + : "let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n"; + get_b_data = b.NumComponents() == 4 ? "let b = " + b.GetByOffset("offset_b / 4") + ";\n" + : "let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n"; + } else { + // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. + common = "var outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + + ";\n" + "let offset_a0 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b0 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 1") + + ";\n" + "let offset_a1 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b1 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 2") + + ";\n" + "let offset_a2 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b2 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 3") + + ";\n" + "let offset_a3 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b3 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; + get_a_data = "let a = vec4(" + a.GetByOffset("offset_a0") + ", " + + a.GetByOffset("offset_a1") + ", " + + a.GetByOffset("offset_a2") + ", " + + a.GetByOffset("offset_a3") + ");\n"; + get_b_data = "let b = vec4(" + b.GetByOffset("offset_b0") + ", " + + b.GetByOffset("offset_b1") + ", " + + b.GetByOffset("offset_b2") + ", " + + b.GetByOffset("offset_b3") + ");\n"; + } + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + common, get_a_data, get_b_data, + c.SetByOffset("global_idx", expression_)); + return Status::OK(); +} + +Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { + auto lhs_tensor = context.Input(0); + auto rhs_tensor = context.Input(1); + const auto& lhs_shape = lhs_tensor->Shape(); + const auto& rhs_shape = rhs_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + auto output_tensor = context.Output(0, output_shape); + int64_t size = output_shape.Size(); + if (size == 0) { + return Status::OK(); + } + + bool is_broadcast = lhs_shape != rhs_shape; + bool is_lhs_scalar = lhs_shape.IsScalar(); + bool is_rhs_scalar = rhs_shape.IsScalar(); + + bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; + bool a_last_dim_divisible_by_4 = false; + bool b_last_dim_divisible_by_4 = false; + bool shared_dimension_divisible_by_4 = false; + size_t num_shared_dimension = 0; + if (!vectorize) { + // check whether vectorize can be enabled + a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && 
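+        // (a vec4 of output must map onto a contiguous run of this input, so
+        //  its trailing dimension has to be a multiple of 4)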
lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0;
+    b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0;
+    if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) {
+      vectorize = true;
+    } else {
+      size_t shared_dimension = 1;
+      for (size_t i = 1; i < output_shape.NumDimensions(); i++) {
+        size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1;
+        size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1;
+        if (dimA == dimB) {
+          shared_dimension *= dimA;
+          num_shared_dimension++;
+        } else {
+          break;
+        }
+      }
+      if (shared_dimension % 4 == 0) {
+        shared_dimension_divisible_by_4 = true;
+        vectorize = true;
+      }
+    }
+  }
+
+  SafeInt<uint32_t> vec_size = (size + 3) / 4;
+  BinaryElementwiseProgram program{kernel_name_,
+                                   expression_,
+                                   is_broadcast,
+                                   is_lhs_scalar,
+                                   is_rhs_scalar,
+                                   vectorize};
+  program
+      .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({
+          {static_cast<uint32_t>(vec_size)},
+      })
+      .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4});
+
+  if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) {
+    // Mode Element-wise
+    // cache hint: "E{is_a_scalar}{is_b_scalar}"
+    program
+        .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4},
+                    {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}})
+        .CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar));
+  } else if (vectorize) {
+    // reshape the dims to merge the shared dimension if available
+    bool need_reshape = shared_dimension_divisible_by_4 && num_shared_dimension > 1;
+    TensorShape reshaped_lhs_shape = need_reshape ? lhs_shape.Slice(0, lhs_shape.NumDimensions() - num_shared_dimension + 1)
+                                                  : lhs_shape;
+    TensorShape reshaped_rhs_shape = need_reshape ? rhs_shape.Slice(0, rhs_shape.NumDimensions() - num_shared_dimension + 1)
+                                                  : rhs_shape;
+    TensorShape reshaped_output_shape = need_reshape ?
output_shape.Slice(0, output_shape.NumDimensions() - num_shared_dimension + 1) + : output_shape; + if (need_reshape) { + reshaped_lhs_shape[reshaped_lhs_shape.NumDimensions() - 1] = lhs_shape.SizeFromDimension(lhs_shape.NumDimensions() - num_shared_dimension); + reshaped_rhs_shape[reshaped_rhs_shape.NumDimensions() - 1] = rhs_shape.SizeFromDimension(rhs_shape.NumDimensions() - num_shared_dimension); + reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension); + } + + if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type}); + } + if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type}); + } + // Mode Vectorize broadcast + // cache hint: "V{a_rank};{b_rank};{output_rank}" + program + .AddIndices(reshaped_output_shape) + .AddIndices(reshaped_lhs_shape) + .AddIndices(reshaped_rhs_shape) + .CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(), + reshaped_rhs_shape.NumDimensions(), + reshaped_output_shape.NumDimensions()}, + ";")); + } else { + // Mode Broadcast + // cache hint: "B" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddIndices(output_tensor->Shape()) + .CacheHint("B"); + } + + return context.RunProgram(program); +} + +#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) 
\ + class OP_TYPE final : public BinaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_KERNEL_2(OP_TYPE, VERSION, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +WEBGPU_BINARY_IMPL(Add, "a + b") +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Div, "a / b") +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Mul, "a * b") +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Sub, "a - b") +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(output_value_t(a), output_value_t(b)))") +WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL_2(Pow, 15, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Equal, "vec4(a == b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 7, 10, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 11, 12, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 13, 18, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Equal, 19, Equal, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Greater, "vec4(a > b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 7, 8, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 9, 12, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Greater, 13, Greater, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Less, 
"vec4(a < b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 7, 8, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 9, 12, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Less, 13, Less, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(GreaterOrEqual, "vec4(a >= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(GreaterOrEqual, 16, GreaterOrEqual, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(LessOrEqual, "vec4(a <= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h new file mode 100644 index 0000000000000..84cbcdf3244d8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class BinaryElementwiseProgram final : public Program { + public: + BinaryElementwiseProgram(const std::string& kernel_name, + const std::string& expression, + const bool is_broadcast, + const bool is_lhs_scalar, + const bool is_rhs_scalar, + const bool vectorize) : Program{kernel_name}, + expression_{expression}, + is_broadcast_{is_broadcast}, + is_lhs_scalar_{is_lhs_scalar}, + is_rhs_scalar_{is_rhs_scalar}, + vectorize_{vectorize} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + bool is_broadcast_; + bool is_lhs_scalar_; + bool is_rhs_scalar_; + bool vectorize_; +}; + +class BinaryElementwise : public WebGpuKernel { + public: + BinaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + + private: + std::string kernel_name_; + std::string expression_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index b4b397b2c4b5f..870dd3df24c73 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -9,8 +9,8 @@ namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); - const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); shader.AppendImplementation(additional_impl_); shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), " let a = ", input.GetByOffset("global_idx"), ";\n 
", @@ -98,7 +98,7 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -113,7 +113,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} { // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); @@ -154,7 +154,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -180,7 +180,7 @@ class Clip final : public UnaryElementwise { : UnaryElementwise{info, "Clip", std::is_same_v ? ClipF16Impl : ClipImpl, - "", ShaderVariable::UseElementTypeAlias} {} + "", ShaderUsage::UseElementTypeAlias} {} Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { const auto* clip_min_tensor = context.Input(1); @@ -240,7 +240,7 @@ class LinearUnit : public UnaryElementwise { const std::string& expression, const std::string& additional_impl, float default_alpha) - : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} { + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { info.GetAttrOrDefault("alpha", &alpha_, default_alpha); } @@ -269,14 +269,14 @@ class Gelu : public UnaryElementwise { "Gelu", info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, info.GetAttrOrDefault("approximate", "none") == "tanh" ? 
TanhImpl : ErfImpl, - ShaderVariable::UseValueTypeAlias} { + ShaderUsage::UseValueTypeAlias} { cache_hint = info.GetAttrOrDefault("approximate", "none"); } }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index de85c18da117a..70fa81d21f95d 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -12,7 +12,7 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage) + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage) : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { } @@ -26,7 +26,7 @@ class UnaryElementwiseProgram final : public Program { private: std::string_view expression_; std::string_view additional_impl_; - ShaderVariable::Usage additional_usage_; + ShaderUsage additional_usage_; }; // TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch @@ -38,11 +38,11 @@ class UnaryElementwise : public WebGpuKernel { const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "", - ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info}, - kernel_name_{kernel_name}, - expression_{expression}, - additional_impl_{additional_impl}, - additional_usage_{usage} {} + ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} protected: std::string cache_hint; @@ -57,7 +57,7 @@ class UnaryElementwise : public WebGpuKernel { std::string kernel_name_; std::string expression_; std::string additional_impl_; - ShaderVariable::Usage additional_usage_; + ShaderUsage additional_usage_; }; constexpr const char ErfImpl[] = R"( diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index f12f6fb8a01c4..21c63f75d26d5 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -263,6 +263,16 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list output return *this; } +ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { + indices_.emplace_back(shape); + return *this; +} + ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { return SetDispatchGroupSize(x, 1, 1); } @@ -309,4 +319,4 @@ ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list WORKGROUP_SIZE = 64; -// represents the scope of a variable in a shader program. 
-// -// this is not a full list of all possible variable scopes in shader programs. -// it only includes what are used in WebGPU EP. -enum class ProgramVariableScope { - Input = 0, // storage buffer variable with access mode "read" - Output = 1, // storage buffer variable with access mode "read_write" - Local = 2, // local variable - - Count // should always be the last element -}; - // data type of variable // // this is not a full list of all possible data types in shader programs. @@ -265,6 +253,10 @@ class ProgramBase { ProgramBase& AddOutput(ProgramOutput&& output); // add multiple program outputs ProgramBase& AddOutputs(std::initializer_list outputs); + // add a program variable for indices + ProgramBase& AddIndices(const TensorShape& shape); + // add a program variable for indices + ProgramBase& AddIndices(TensorShape&& shape); // set the size of dispatch groups. Y and Z are 1 if not specified. ProgramBase& SetDispatchGroupSize(uint32_t x); @@ -330,6 +322,7 @@ class ProgramBase { inline const std::string& CacheHint() const { return cache_hint_; } inline const std::vector& Inputs() const { return inputs_; } inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Indices() const { return indices_; } inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } @@ -351,6 +344,7 @@ class ProgramBase { std::string cache_hint_; std::vector inputs_; std::vector outputs_; + std::vector indices_; uint32_t dispatch_group_size_x_; uint32_t dispatch_group_size_y_; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 3e4fbd33a6bdf..297d211ff1262 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -60,7 +60,9 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); - ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); // code is a large std::string that contains the final shader code std::string code; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 64ed98c78507b..c229e821cbf8c 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -78,24 +78,33 @@ Status ShaderHelper::Init() { return Status::OK(); } -const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ShaderVariable::Usage usage) { - const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); +const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { + const size_t input_index = input_vars_.size(); ORT_ENFORCE(input_index < program_.Inputs().size(), "Too many inputs in the program (", program_.Inputs().size(), ")"); const auto& dims = program_.Inputs()[input_index].use_override_shape ? 
program_.Inputs()[input_index].override_shape : program_.Inputs()[input_index].tensor->Shape();
-  return AddVariableImpl(ProgramVariableScope::Input, name, usage, dims);
+  return AddVariableImpl(true, name, usage, dims);
 }
 
-const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ShaderVariable::Usage usage) {
-  const size_t output_index = vars_[std::underlying_type<ProgramVariableScope>::type(ProgramVariableScope::Output)].size();
+const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) {
+  const size_t output_index = output_vars_.size();
   ORT_ENFORCE(output_index < program_.Outputs().size(),
               "Too many outputs in the program (", program_.Outputs().size(), ")");
   const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape
                                                                          : program_.Outputs()[output_index].tensor->Shape();
-  return AddVariableImpl(ProgramVariableScope::Output, name, usage, dims);
+  return AddVariableImpl(false, name, usage, dims);
+}
+
+const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) {
+  const size_t indices_index = indices_vars_.size();
+  return *indices_vars_.emplace_back(
+      std::make_unique<ShaderIndicesHelper>(name,
+                                            ProgramVariableDataType::InvalidType,
+                                            use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None,
+                                            program_.Indices()[indices_index]));
 }
 
 #ifndef NDEBUG  // if debug build
@@ -162,7 +171,7 @@ Status ValidateVariableShape(const TensorShape& origin_shape,
 }
 
 // Validate if the dependency and variable usage match
-Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderVariable::Usage usage, bool is_input) {
+Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderUsage usage, bool is_input) {
   bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank;
   bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape;
   bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type;
@@ -172,7 +181,7 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh
                 "Dependency cannot set for both \"Rank\" and \"Shape\".");
 
   // if dependency is set for shape, it's already part of the shader cache. no need to use uniform.
-  ORT_RETURN_IF(dependency_shape && (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform,
+  ORT_RETURN_IF(dependency_shape && (usage & ShaderUsage::UseUniform),
                 "Dependency is set for \"Shape\", using uniform for shape is not allowed.");
 
   // for input variable, check is more strict.
@@ -180,11 +189,11 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh
   if (is_input) {
     // if dependency is not set for type, should not use type alias for element and value.
     // storage type is always used. so setting not depending on type is at user's own risk.
-    ORT_RETURN_IF(!dependency_type && (usage & (ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias)),
+    ORT_RETURN_IF(!dependency_type && (usage & (ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias)),
                   "Input dependency is not set for \"Type\", but type alias for element type or value type is used.");
 
     // if dependency is not set for rank and shape, the shader should not use shape and stride.
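+    // e.g. an input declared with ProgramTensorMetadataDependency::TypeAndRank can
+    // combine ShaderUsage::UseUniform with ShaderUsage::UseShapeAndStride, while a
+    // Shape dependency bakes the dims into the shader source instead.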
-    ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderVariable::UseShapeAndStride),
+    ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderUsage::UseShapeAndStride),
                   "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used.");
   }
 
@@ -192,7 +201,7 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh
 }
 }  // namespace
 
-Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const {
+Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const {
   ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_));
   ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(),
                                             input.use_override_shape,
@@ -202,7 +211,7 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar
   return Status::OK();
 }
 
-Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const {
+Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const {
   ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_));
   ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(),
                                             output.use_override_shape,
@@ -215,93 +224,97 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV
 
 #endif  // NDEBUG
 
-const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope,
-                                                    const std::string& name,
-                                                    ShaderVariable::Usage usage,
-                                                    const TensorShape& dims) {
-  if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) {
-    ORT_ENFORCE(vars_[std::underlying_type<ProgramVariableScope>::type(ProgramVariableScope::Input)].size() +
-                        vars_[std::underlying_type<ProgramVariableScope>::type(ProgramVariableScope::Output)].size() <
-                    limits_.maxStorageBuffersPerShaderStage,
-                "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage);
-  }
+const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input,
+                                                          const std::string& name,
+                                                          ShaderUsage usage,
+                                                          const TensorShape& dims) {
+  ORT_ENFORCE(input_vars_.size() + output_vars_.size() < limits_.maxStorageBuffersPerShaderStage,
+              "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage);
 
-  auto& vars = vars_[std::underlying_type<ProgramVariableScope>::type(scope)];
   ProgramVariableDataType type = ProgramVariableDataType::InvalidType;
+  auto& vars = is_input ? input_vars_ : output_vars_;
 
-  if (scope == ProgramVariableScope::Input) {
+  if (is_input) {
     const auto& input = program_.Inputs()[vars.size()];
     type = input.var_type;
-  } else if (scope == ProgramVariableScope::Output) {
+  } else {
     const auto& output = program_.Outputs()[vars.size()];
     type = output.var_type;
-  } else {
-    ORT_NOT_IMPLEMENTED("Local variables are not supported yet.");
   }
 
-  const auto& var = vars.emplace_back(std::make_unique<ShaderVariable>(name, type, usage, dims));
+  const auto& var = vars.emplace_back(std::make_unique<ShaderVariableHelper>(name, type, usage, dims));
   return *var;
 }
 
-Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
-  const auto& input_vars = vars_[static_cast<int>(ProgramVariableScope::Input)];
-  const auto& output_vars = vars_[static_cast<int>(ProgramVariableScope::Output)];
-
-  // Validate input/output as dependencies of shape_uniforms
-  ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(),
-                    "Mismatched input variable count. 
Shader: ", input_vars.size(), ", Program: ", program_.Inputs().size()); - ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), - "Mismatched output variable count. Shader: ", output_vars.size(), ", Program: ", program_.Outputs().size()); - - for (size_t i = 0; i < input_vars.size(); i++) { +Status ShaderHelper::ValidateShapeForInputs() const { + // Validate input as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size()); + for (size_t i = 0; i < input_vars_.size(); i++) { #ifndef NDEBUG // if debug build // Validate input shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars_[i])); #endif // check input dependencies with actual usages. - auto usage = input_vars[i]->usage_; - bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto usage = input_vars_[i]->usage_; auto dependency = program_.Inputs()[i].dependency; bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; - if (use_uniform) { - ORT_RETURN_IF_NOT((use_rank || input_vars[i]->rank_ < 2) && !use_shape, - "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); - } else { - ORT_RETURN_IF_NOT(use_shape, - "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); - // If you want neither hard-coded shape nor shape uniform, set UseUniform with a flattened shape (rank=1). - // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars_[i]->rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, use a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } } } + return Status::OK(); +} + +Status ShaderHelper::ValidateShapeForOutputs() const { + // Validate output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size()); - for (size_t i = 0; i < output_vars.size(); i++) { + for (size_t i = 0; i < output_vars_.size(); i++) { #ifndef NDEBUG // if debug build // Validate output shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars_[i])); #endif // check output dependencies with actual usages. 
- auto usage = output_vars[i]->usage_; - bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto usage = output_vars_[i]->usage_; auto dependency = program_.Outputs()[i].dependency; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; - if (use_uniform) { - // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not - // necessarily a part of the cache key. - ORT_RETURN_IF_NOT(!use_shape, - "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); - } else { - ORT_RETURN_IF_NOT(use_shape, - "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } } } return Status::OK(); } +Status ShaderHelper::ValidateIndices() const { + ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(), + "Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size()); + + return Status::OK(); +} + Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -362,12 +375,10 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // Input/output variables // size_t variable_count = 0; - const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; - for (const auto& input : input_vars) { + for (const auto& input : input_vars_) { ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; } - const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; - for (const auto& output : output_vars) { + for (const auto& output : output_vars_) { ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; } @@ -378,22 +389,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // store shape uniform ranks in shape_uniform_ranks bool use_any_shape_uniform = false; ORT_ENFORCE(shape_uniform_ranks.size() == 0); - shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); + shape_uniform_ranks.reserve(input_vars_.size() + output_vars_.size() + indices_vars_.size()); - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input->usage_ & ShaderVariable::UseUniform) && - (input->usage_ & ShaderVariable::UseShapeAndStride) && + for (const auto& input : input_vars_) { + bool use_uniform = (input->usage_ & ShaderUsage::UseUniform) && + (input->usage_ & ShaderUsage::UseShapeAndStride) && input->rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? 
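+                                 // (a recorded rank of 0 means "no shape uniform for this variable")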
input->rank_ : 0); } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output->usage_ & ShaderVariable::UseUniform) && - (output->usage_ & ShaderVariable::UseShapeAndStride) && + for (const auto& output : output_vars_) { + bool use_uniform = (output->usage_ & ShaderUsage::UseUniform) && + (output->usage_ & ShaderUsage::UseShapeAndStride) && output->rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); } + for (const auto& indices : indices_vars_) { + bool use_uniform = (indices->usage_ & ShaderUsage::UseUniform) && + (indices->usage_ & ShaderUsage::UseShapeAndStride) && + indices->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? indices->rank_ : 0); + } if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), program_.UniformVariables().cend(), @@ -430,9 +448,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } }; - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + for (const auto& input : input_vars_) { const size_t rank = input->rank_; - if (rank > 0 && (input->usage_ & ShaderVariable::Usage::UseUniform) && (input->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + if (rank > 0 && (input->usage_ & ShaderUsage::UseUniform) && (input->usage_ & ShaderUsage::UseShapeAndStride)) { std::string shape = input->name_ + "_shape"; std::string stride = input->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -440,9 +458,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + for (const auto& output : output_vars_) { const size_t rank = output->rank_; - if (rank > 0 && (output->usage_ & ShaderVariable::Usage::UseUniform) && (output->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + if (rank > 0 && (output->usage_ & ShaderUsage::UseUniform) && (output->usage_ & ShaderUsage::UseShapeAndStride)) { std::string shape = output->name_ + "_shape"; std::string stride = output->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -450,6 +468,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } } + for (const auto& indices : indices_vars_) { + const size_t rank = indices->rank_; + if (rank > 0 && (indices->usage_ & ShaderUsage::UseUniform) && (indices->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = indices->name_ + "_shape"; + std::string stride = indices->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { const auto& uniform_def = program_metadata_.uniform_variables[i]; const auto& uniform_value = program_.UniformVariables()[i]; @@ -465,10 +493,14 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // Indices helper // ss << "\n"; - for (const auto& var_group : vars_) { - for (const auto& var : var_group) { - var->Impl(ss); - } + for (const auto& var : input_vars_) { + var->Impl(ss); + } + for (const auto& var : output_vars_) { + var->Impl(ss); + } + for (const auto& var : indices_vars_) { + var->Impl(ss); } ss << "\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h 
b/onnxruntime/core/providers/webgpu/shader_helper.h index 811ae3cfa15cc..bdc14669cfb51 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -80,14 +80,17 @@ class ShaderHelper final { // Add an input variable to the shader. // // depending on the usage of the variable, additional code may be generated. - const ShaderVariable& AddInput(const std::string& name, - ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + const ShaderVariableHelper& AddInput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); // Add an output variable to the shader. // // depending on the usage of the variable, additional code may be generated. - const ShaderVariable& AddOutput(const std::string& name, - ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + const ShaderVariableHelper& AddOutput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an indices variable to the shader. + const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); // Append additional implementation code to the shader. // @@ -136,17 +139,19 @@ class ShaderHelper final { } } - const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, - const std::string& name, - ShaderVariable::Usage usage, - const TensorShape& dims); + const ShaderVariableHelper& AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims); #ifndef NDEBUG // if debug build - Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; - Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; + Status ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const; #endif - Status ValidateShapeForInputsAndOutputs() const; + Status ValidateShapeForInputs() const; + Status ValidateShapeForOutputs() const; + Status ValidateIndices() const; // Generate source code. 
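+  // (invoked by ProgramManager::Build after the kernel's GenerateShaderCode and
+  // the ValidateShapeForInputs / ValidateShapeForOutputs / ValidateIndices passes)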
//
@@ -171,7 +176,9 @@ class ShaderHelper final {
   const ProgramBase& program_;
   const ProgramMetadata& program_metadata_;
-  std::array<std::vector<std::unique_ptr<ShaderVariable>>, static_cast<size_t>(ProgramVariableScope::Count)> vars_;
+  std::vector<std::unique_ptr<ShaderVariableHelper>> input_vars_;
+  std::vector<std::unique_ptr<ShaderVariableHelper>> output_vars_;
+  std::vector<std::unique_ptr<ShaderIndicesHelper>> indices_vars_;
 
   std::ostringstream additional_implementation_;
   std::ostringstream body_;
diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc
index 07c5915be466b..f2a5b049b4777 100644
--- a/onnxruntime/core/providers/webgpu/shader_variable.cc
+++ b/onnxruntime/core/providers/webgpu/shader_variable.cc
@@ -76,7 +76,7 @@ inline std::string GetIndicesType(int rank) {
 }
 }  // namespace
 
-ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims)
+ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims)
     : name_(name),
       type_(type),
       num_components_{NumberOfComponents(type)},
@@ -86,30 +86,33 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty
       indices_type_{GetIndicesType(rank_)},
       value_type_alias_{name_ + "_value_t"},
       element_type_alias_{name_ + "_element_t"},
-      indices_type_alias_{name_ + "_indices_t"} {
+      indices_type_alias_{name_ + "_indices_t"} {}
+
+ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims)
+    : ShaderIndicesHelper{name, type, usage, dims} {
   ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_);
   ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_);
 }
 
-void ShaderVariable::Impl(std::ostringstream& ss) const {
+void ShaderIndicesHelper::Impl(std::ostringstream& ss) const {
   // Start generating code
-  const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape";
-  const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride";
+  const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape";
+  const std::string stride = (usage_ & ShaderUsage::UseUniform) ? "uniforms."
+ name_ + "_stride" : name_ + "_stride"; // Types - if (usage_ & UseValueTypeAlias) { + if (usage_ & ShaderUsage::UseValueTypeAlias) { SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } - if (usage_ & UseIndicesTypeAlias) { + if (usage_ & ShaderUsage::UseIndicesTypeAlias) { SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } - if (usage_ & UseElementTypeAlias) { + if (usage_ & ShaderUsage::UseElementTypeAlias) { SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (use shape and stride is enabled) - if (!(usage_ & UseUniform) && (usage_ & UseShapeAndStride) && rank_ > 0) { + if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; @@ -138,7 +141,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn o2i_{name}" - if (usage_ & UseOffsetToIndices) { + if (usage_ & ShaderUsage::UseOffsetToIndices) { if (rank_ >= 2) { SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); SS(" var indices: ", IndicesType(), ";\n"); @@ -157,7 +160,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn i2o_{name}" - if (usage_ & UseIndicesToOffset) { + if (usage_ & ShaderUsage::UseIndicesToOffset) { if (rank_ >= 2) { SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); @@ -170,7 +173,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn {res_name}_bi2o_{name}" - if (usage_ & UseBroadcastedIndicesToOffset) { + if (usage_ & ShaderUsage::UseBroadcastedIndicesToOffset) { if (rank_ > 0) { for (const auto& broadcasted_result_ptr : broadcasted_to_) { const auto& broadcasted_result = *broadcasted_result_ptr; @@ -190,9 +193,13 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } } +} + +void ShaderVariableHelper::Impl(std::ostringstream& ss) const { + ShaderIndicesHelper::Impl(ss); // Implementation of "fn set_{name}" - if (usage_ & UseSet) { + if (usage_ & ShaderUsage::UseSet) { if (rank_ >= 2) { SS("fn set_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { @@ -209,7 +216,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn set_{name}_by_indices" - if (usage_ & UseSetByIndices) { + if (usage_ & ShaderUsage::UseSetByIndices) { if (rank_ >= 2) { SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); @@ -218,7 +225,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn get_{name}" - if (usage_ & UseGet) { + if (usage_ & ShaderUsage::UseGet) { if (rank_ >= 2) { SS("fn get_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { @@ -235,7 +242,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn get_{name}_by_indices" - if (usage_ & UseGetByIndices) { + if (usage_ & ShaderUsage::UseGetByIndices) { if (rank_ >= 2) { SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); @@ -244,7 +251,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } -std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { +std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { 
std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -270,7 +277,7 @@ std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { return ss.str(); } -std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string_view value) const { +std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -294,20 +301,20 @@ std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string return ss.str(); } -std::string_view ShaderVariable::StorageType() const { +std::string_view ShaderVariableHelper::StorageType() const { return STORAGE_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::ValueType() const { - return (usage_ & UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; +std::string_view ShaderVariableHelper::ValueType() const { + return (usage_ & ShaderUsage::UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::ElementType() const { - return (usage_ & UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; +std::string_view ShaderVariableHelper::ElementType() const { + return (usage_ & ShaderUsage::UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::IndicesType() const { - return (usage_ & UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; +std::string_view ShaderIndicesHelper::IndicesType() const { + return (usage_ & ShaderUsage::UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; } } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 71822a61f7a77..326c6814410de 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -37,9 +37,8 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; } -class ShaderVariable { - public: - enum Usage : uint32_t { +struct ShaderUsage { + enum : uint32_t { None = 0, // no usage. this means no additional implementation code will be generated. UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) @@ -53,17 +52,21 @@ class ShaderVariable { UseGet = 1024, // use implementation of fn get_{name} UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices UseUniform = 32768, // use uniform for shape and stride - }; + } usage; - ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); + ShaderUsage(decltype(usage) usage) : usage{usage} {} + ShaderUsage(uint32_t usage) : usage{usage} {} - ShaderVariable(ShaderVariable&&) = default; - ShaderVariable& operator=(ShaderVariable&&) = default; + explicit operator bool() { + return usage != None; + } +}; - // get the name of the variable. - inline std::string_view Name() const { return name_; } +// A helper class to make it easier to generate shader code related to indices calculation. 
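+// Methods that return WGSL snippets also record the matching ShaderUsage bit on
+// the variable (usage_ is mutable), so that Impl() later emits only the helper
+// functions (o2i_{name}, i2o_{name}, {res}_bi2o_{name}) that are actually used.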
+class ShaderIndicesHelper {
+ public:
+  ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims);
 
-  // get the number of components of the variable.
   inline int NumComponents() const { return num_components_; }
 
   // create a WGSL expression ({varname}_indices_t) for getting indices from offset.
@@ -77,7 +80,7 @@ class ShaderVariable {
   // create a WGSL expression (u32) for getting original offset from broadcasted indices.
   // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices.
   // \param broadcasted_result: the broadcasted result variable.
-  inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const;
+  inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const;
 
   // create a WGSL expression ({varname}_indices_t) as an indices literal
   // \param init: a list of indices values.
@@ -97,6 +100,41 @@ class ShaderVariable {
   template <typename TIdx>
   inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const;
 
+ protected:
+  ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper);
+
+  void Impl(std::ostringstream& ss) const;
+
+  std::string_view IndicesType() const;
+
+  std::string name_;
+  ProgramVariableDataType type_;  // for variable
+  int num_components_;            // for variable
+  int rank_;
+  TensorShape dims_;
+
+  mutable ShaderUsage usage_;
+  mutable std::set<const ShaderIndicesHelper*> broadcasted_to_;
+
+  // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it.
+  std::string indices_type_;
+
+  // the alias for the types
+  std::string value_type_alias_;
+  std::string element_type_alias_;
+  std::string indices_type_alias_;
+
+  friend class ShaderHelper;
+};
+
+// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation.
+class ShaderVariableHelper : public ShaderIndicesHelper {
+ public:
+  ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims);
+
+  ShaderVariableHelper(ShaderVariableHelper&&) = default;
+  ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default;
+
   // create a WGSL statement for setting data at the given indices.
   // \param args: a list of indices values (u32) followed by a value ({varname}_value_t).
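+  // (e.g. for a rank-2 variable "x", Set("i", "j", "v") roughly expands to the
+  // WGSL call "set_x(i, j, v);" and marks UseSet so the helper function is emitted)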
+
+// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation.
+class ShaderVariableHelper : public ShaderIndicesHelper {
+ public:
+  ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims);
+
+  ShaderVariableHelper(ShaderVariableHelper&&) = default;
+  ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default;
+
   // create a WGSL statement for setting data at the given indices.
   // \param args: a list of indices values (u32) followed by a value ({varname}_value_t).
   template <typename... TIndicesAndValue>
@@ -128,12 +166,7 @@ class ShaderVariable {
   inline std::string GetByOffset(TOffset&& offset) const;
 
  private:
-  friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b);
-  friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b);
-  friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b);
-  friend ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b);
-
-  ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable);
+  ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);
 
   void Impl(std::ostringstream& ss) const;
 
@@ -142,39 +175,23 @@ class ShaderVariable {
   std::string_view StorageType() const;
   std::string_view ValueType() const;
   std::string_view ElementType() const;
-  std::string_view IndicesType() const;
-
-  std::string name_;
-  ProgramVariableDataType type_;
-  int num_components_;
-  int rank_;
-  TensorShape dims_;
-
-  mutable Usage usage_;
-  mutable std::set<const ShaderVariable*> broadcasted_to_;
-
-  // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it.
-  std::string indices_type_;
-
-  // the alias for the types
-  std::string value_type_alias_;
-  std::string element_type_alias_;
-  std::string indices_type_alias_;
 
   friend class ShaderHelper;
 };
 
-inline ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b) {
-  return (ShaderVariable::Usage)((uint32_t&)a | (uint32_t&)b);
+inline ShaderUsage operator|(ShaderUsage a, ShaderUsage b) {
+  return (uint32_t)a.usage | (uint32_t)b.usage;
 }
-inline ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b) {
-  return (ShaderVariable::Usage)((uint32_t&)a & (uint32_t&)b);
+inline ShaderUsage operator&(ShaderUsage a, ShaderUsage b) {
+  return (uint32_t)a.usage & (uint32_t)b.usage;
 }
-inline ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b) {
-  return (ShaderVariable::Usage&)((uint32_t&)a |= (uint32_t&)b);
+inline ShaderUsage& operator|=(ShaderUsage& a, ShaderUsage b) {
+  (uint32_t&)a.usage |= (uint32_t)b.usage;
+  return a;
 }
-inline ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b) {
-  return (ShaderVariable::Usage&)((uint32_t&)a &= (uint32_t&)b);
+inline ShaderUsage& operator&=(ShaderUsage& a, ShaderUsage b) {
+  (uint32_t&)a.usage &= (uint32_t)b.usage;
+  return a;
 }
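
The rewritten operators also drop the old pattern of casting enum lvalues through (uint32_t&); they now do plain integer arithmetic on the wrapped field, and the declared ShaderUsage return type is produced by the uint32_t converting constructor. Condensed illustration (not part of the patch):

    ShaderUsage u = ShaderUsage::UseSet | ShaderUsage::UseSetByIndices;  // body yields uint32_t, ctor converts back
    u &= ShaderUsage::UseSet;                                            // operator&= updates u.usage in place
    bool has_set = static_cast<bool>(u & ShaderUsage::UseSet);           // true
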
 
 namespace detail {
@@ -192,20 +209,24 @@ std::string pass_as_string(T&& v) {
 }
 }  // namespace detail
 
-inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const {
-  usage_ |= UseOffsetToIndices | UseShapeAndStride;
+inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const {
+  usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride;
   return rank_ < 2 ? std::string{offset_expr}
                    : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')');
 }
 
-inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const {
-  usage_ |= UseIndicesToOffset | UseShapeAndStride;
+inline std::string ShaderIndicesHelper::IndicesToOffset(std::string_view indices_expr) const {
+  usage_ |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride;
   return rank_ < 2 ? std::string{indices_expr}
                    : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')');
 }
 
-inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const {
-  usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride;
+inline std::string ShaderIndicesHelper::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const {
+  ORT_ENFORCE(broadcasted_result.num_components_ == -1 ||
+                  num_components_ == -1 ||
+                  broadcasted_result.num_components_ == num_components_,
+              "number of components should be the same for 2 variables to calculate");
+  usage_ |= ShaderUsage::UseBroadcastedIndicesToOffset | ShaderUsage::UseShapeAndStride;
   broadcasted_to_.insert(&broadcasted_result);
   return rank_ == 0
              ? "0"
@@ -213,8 +234,8 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view i
 }
 
 template <typename... TIndices>
-inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderIndicesHelper::Indices(TIndices&&... indices_args) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   return rank_ == 0
              ? "0"
              : MakeStringWithClassicLocale(IndicesType(), "(",
@@ -223,77 +244,77 @@ inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const {
 }
 
 template <typename TIdx, typename TVal>
-inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderIndicesHelper::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';')
                    : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';');
 }
 
 template <typename TIdx>
-inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderIndicesHelper::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   return rank_ < 2 ? std::string{indices_var}
                    : GetElementAt(indices_var, idx_expr, rank_);
 }
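
The new ORT_ENFORCE in BroadcastedIndicesToOffset fails at shader-build time when the two variables disagree on per-value component count, instead of generating WGSL that indexes with mismatched packing; a num_components_ of -1 appears to act as a wildcard for helpers that carry no component count. Distilled form of the rule (sketch only):

    // Compatibility rule encoded by the ORT_ENFORCE above (-1 = no component count).
    auto components_compatible = [](int a, int b) { return a == -1 || b == -1 || a == b; };
    // components_compatible(4, 4)  -> true   (both vec4-packed)
    // components_compatible(4, -1) -> true   (wildcard)
    // components_compatible(4, 1)  -> false  (would have produced inconsistent offsets)
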
 
 template <typename TOffset, typename TValue>
-inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) const {
+inline std::string ShaderVariableHelper::SetByOffset(TOffset&& offset, TValue&& value) const {
   return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value));
 }
 
 template <typename... TIndicesAndValue>
-inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderVariableHelper::Set(TIndicesAndValue&&... args) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)");
   if constexpr (sizeof...(TIndicesAndValue) == 1) {
     return SetByOffset("0", std::forward<TIndicesAndValue>(args)...);
   } else if constexpr (sizeof...(TIndicesAndValue) == 2) {
     return SetByOffset(std::forward<TIndicesAndValue>(args)...);
   } else {
-    usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset;
+    usage_ |= ShaderUsage::UseSet | ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset;
     return MakeStringWithClassicLocale("set_", name_, '(',
                                        absl::StrJoin(std::forward_as_tuple(std::forward<TIndicesAndValue>(args)...), ", "),
                                        ");");
   }
 }
 
-inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderVariableHelper::SetByIndices(std::string_view indices_var, std::string_view value) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   if (rank_ < 2) {
     return SetByOffset(indices_var, value);
   } else {
-    usage_ |= UseSetByIndices | UseIndicesToOffset;
+    usage_ |= ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset;
     return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");");
   }
 }
 
 template <typename TOffset>
-inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const {
+inline std::string ShaderVariableHelper::GetByOffset(TOffset&& offset) const {
   return GetByOffsetImpl(detail::pass_as_string(offset));
 }
 
 template <typename... TIndices>
-inline std::string ShaderVariable::Get(TIndices&&... indices) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderVariableHelper::Get(TIndices&&... indices) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)");
   if constexpr (sizeof...(TIndices) == 0) {
     return GetByOffset("0");
   } else if constexpr (sizeof...(TIndices) == 1) {
     return GetByOffset(std::forward<TIndices>(indices)...);
   } else {
-    usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset;
+    usage_ |= ShaderUsage::UseGet | ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset;
     return MakeStringWithClassicLocale("get_", name_, '(',
                                        absl::StrJoin(std::forward_as_tuple(std::forward<TIndices>(indices)...), ", "),
                                        ')');
   }
 }
 
-inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const {
-  usage_ |= UseShapeAndStride;
+inline std::string ShaderVariableHelper::GetByIndices(std::string_view indices_var) const {
+  usage_ |= ShaderUsage::UseShapeAndStride;
   if (rank_ < 2) {
     return GetByOffset(indices_var);
   } else {
-    usage_ |= UseGetByIndices | UseIndicesToOffset;
+    usage_ |= ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset;
     return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")");
   }
 }
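
For reference while reviewing the call sites below: rank decides what these members emit. Rank < 2 collapses to a direct offset access, while higher ranks delegate to generated WGSL functions and record the corresponding Use* flags so the implementation is emitted. Illustrative expansions for a rank-3 variable named "y" (derived from the string building above):

    std::string set_stmt = y.Set("i", "j", "k", "v");    // -> "set_y(i, j, k, v);"  (marks UseSet)
    std::string get_expr = y.Get("i", "j", "k");         // -> "get_y(i, j, k)"      (marks UseGet)
    std::string raw_load = y.GetByOffset("global_idx");  // direct load, no helper function emitted
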
diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc
index 45084472d3537..a106583651885 100644
--- a/onnxruntime/core/providers/webgpu/tensor/expand.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc
@@ -11,8 +11,8 @@ namespace onnxruntime {
 namespace webgpu {
 
 Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderVariable::UseUniform);
-  const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"),
                              "  let output_indices = ", output.OffsetToIndices("global_idx"), ";\n",
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
index 68af858d515c2..b620e83843b2f 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -61,8 +61,8 @@ const std::string AppendPermFunction(gsl::span<const size_t> perm) {
 }
 
 Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias);
-  const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias);
+  const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
   shader.AppendImplementation(AppendPermFunction(this->perm_));
   shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
                              "  let indices = ", output.OffsetToIndices("global_idx"),
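
expand.cc and transpose.cc only change the flag spelling; the kernel structure they share is untouched. For orientation, the common skeleton looks roughly like this ("SomeProgram" is illustrative, not a class in this patch):

    Status SomeProgram::GenerateShaderCode(ShaderHelper& shader) const {
      const auto& input = shader.AddInput("x", ShaderUsage::UseUniform);
      const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform);
      // Guard the main body against out-of-range invocations, then index via o2i_.
      shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
                                 "  let indices = ", output.OffsetToIndices("global_idx"), ";\n");
      return Status::OK();
    }
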
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 11a337cd3e37e..66b1c2c7fafac 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -229,17 +229,16 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog
   std::vector<ProgramUniformVariableValue> shape_uniforms;
   shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2);
   if (ValidationMode() >= ValidationMode::Basic) {
-    ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(),
+    ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size() + program.Indices().size(),
                       "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(),
-                      ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")");
+                      ") does not match current program (input: ", inputs.size(),
+                      ", output: ", outputs.size(),
+                      ", indices: ", program.Indices().size(), ")");
   }
-  for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) {
+
+  auto append_shape_uniforms = [&shape_uniforms, program_artifact](size_t i, const TensorShape& shape) {
     SafeInt<int> expected_rank = program_artifact->shape_uniform_ranks[i];
     if (expected_rank > 0) {
-      const auto& shape = i < inputs.size() ? (inputs[i].use_override_shape ? inputs[i].override_shape
-                                                                            : inputs[i].tensor->Shape())
-                                            : (outputs[i - inputs.size()].use_override_shape ? outputs[i - inputs.size()].override_shape
-                                                                                             : outputs[i - inputs.size()].tensor->Shape());
       ORT_RETURN_IF(expected_rank != shape.NumDimensions(),
                     "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank,
                     ", Actual: ", shape.NumDimensions());
@@ -258,6 +257,19 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog
       shape_uniforms.emplace_back(gsl::make_span(stride));
     }
-  }
+    return Status::OK();
+  };
+
+  for (size_t i = 0; i < inputs.size(); i++) {
+    ORT_RETURN_IF_ERROR(append_shape_uniforms(i,
+                                              inputs[i].use_override_shape ? inputs[i].override_shape : inputs[i].tensor->Shape()));
+  }
+  for (size_t i = 0; i < outputs.size(); i++) {
+    ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size(),
+                                              outputs[i].use_override_shape ? outputs[i].override_shape : outputs[i].tensor->Shape()));
+  }
+  for (size_t i = 0; i < program.Indices().size(); i++) {
+    ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size() + outputs.size(), program.Indices()[i]));
+  }
 
   const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size();
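
The refactor above makes the ordering contract for shape uniforms explicit: shape_uniform_ranks must line up with inputs first, then outputs, then the program's extra indices shapes, and validation now counts all three groups. A sketch of the slot arithmetic the three loops rely on (illustrative only; k is the position within its own group):

    // Slot of the k-th entry of each group within shape_uniform_ranks:
    //   input  k -> k
    //   output k -> inputs.size() + k
    //   indices k -> inputs.size() + outputs.size() + k
    size_t indices_slot = inputs.size() + outputs.size() + k;
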
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index 444f07e1664b8..abd471578146c 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -456,36 +456,36 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       KERNEL_CREATE_INFO(20, Gelu),
 
       // // binary - math
-      // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 13, Add),
-      // KERNEL_CREATE_INFO(14, Add),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub),
-      // KERNEL_CREATE_INFO(14, Sub),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul),
-      // KERNEL_CREATE_INFO(14, Mul),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 12, Div),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 13, Div),
-      // KERNEL_CREATE_INFO(14, Div),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow),
-      // KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow),
-      // KERNEL_CREATE_INFO(15, Pow),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal),
-      // KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal),
-      // KERNEL_CREATE_INFO(19, Equal),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
-      // KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
-      // KERNEL_CREATE_INFO(13, Greater),
-      // KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual),
-      // KERNEL_CREATE_INFO(16, GreaterOrEqual),
-      // KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
-      // KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
-      // KERNEL_CREATE_INFO(13, Less),
-      // KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual),
-      // KERNEL_CREATE_INFO(16, LessOrEqual),
+      KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Add),
+      KERNEL_CREATE_INFO(14, Add),
+      KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub),
+      KERNEL_CREATE_INFO(14, Sub),
+      KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul),
+      KERNEL_CREATE_INFO(14, Mul),
+      KERNEL_CREATE_INFO_VERSIONED(7, 12, Div),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Div),
+      KERNEL_CREATE_INFO(14, Div),
+      KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow),
+      KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow),
+      KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow),
+      KERNEL_CREATE_INFO(15, Pow),
+      KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal),
+      KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal),
+      KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal),
+      KERNEL_CREATE_INFO(19, Equal),
+      KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater),
+      KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater),
+      KERNEL_CREATE_INFO(13, Greater),
+      KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual),
+      KERNEL_CREATE_INFO(16, GreaterOrEqual),
+      KERNEL_CREATE_INFO_VERSIONED(7, 8, Less),
+      KERNEL_CREATE_INFO_VERSIONED(9, 12, Less),
+      KERNEL_CREATE_INFO(13, Less),
+      KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual),
+      KERNEL_CREATE_INFO(16, LessOrEqual),
 
       // BuildKernelCreateInfo,
       // BuildKernelCreateInfo,
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index 4ca915dd394c1..4aa3e9c6b37a3 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -369,6 +369,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) {
 #endif
 }
 
+TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) {
+  OpTester test("Add");
+
+  test.AddInput<float>("A", {2, 2, 2},
+                       {101.0f, 102.0f,
+                        103.0f, 104.0f,
+
+                        201.0f, 202.0f,
+                        203.0f, 204.0f});
+  test.AddInput<float>("B", {1, 2, 2},
+                       {010.0f, 020.0f,
+                        030.0f, 040.0f});
+  test.AddOutput<float>("C", {2, 2, 2},
+                        {111.0f, 122.0f,
+                         133.0f, 144.0f,
+
+                         211.0f, 222.0f,
+                         233.0f, 244.0f});
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
 TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) {
   OpTester test("Add");
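
The new test pins the {2, 2, 2} + {1, 2, 2} case, where B's size-1 leading axis is reused across A's first dimension; this is the multi-dimensional broadcast path the Add/Sub/Mul/Div kernels enabled above must handle. Element-wise, the expectation reduces to this reference loop (illustration, not test code):

    // C[d0][d1][d2] = A[d0][d1][d2] + B[0][d1][d2]; e.g. 201.0f + 10.0f == 211.0f.
    for (int d0 = 0; d0 < 2; ++d0)
      for (int d1 = 0; d1 < 2; ++d1)
        for (int d2 = 0; d2 < 2; ++d2)
          C[d0][d1][d2] = A[d0][d1][d2] + B[0][d1][d2];
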