Skip to content

Commit

Permalink
w2
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 16, 2024
1 parent 316e994 commit 4d70ff9
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 49 deletions.
102 changes: 62 additions & 40 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,66 @@
namespace onnxruntime {
namespace webgpu {
Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("input_a", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
const auto& c = shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" :
"let a = " + a.GetByOffset("global_idx") + ";\n";
std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n" :
"let b = " + b.GetByOffset("global_idx") + ";\n";
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" : "let a = " + a.GetByOffset("global_idx") + ";\n";
std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n" : "let b = " + b.GetByOffset("global_idx") + ";\n";
if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) {
if (vectorize_) {
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx * 4") + ";\n"
"let offset_a = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
const std::string a_data = a.Num_Components() == 4 ?
"let a = " + a.GetByOffset("offset_a / 4") + ";\n" :
"let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n";
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx * 4") +
";\n"
"let offset_a = " +
a.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"let offset_b = " +
b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
const std::string a_data = a.NumComponents() == 4 ? "let a = " + a.GetByOffset("offset_a / 4") + ";\n" : "let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n";
get_a_data = common + a_data;
get_b_data = b.Num_Components() == 4 ?
"let b = " + b.GetByOffset("offset_b / 4") + ";\n" :
"let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n";
get_b_data = b.NumComponents() == 4 ? "let b = " + b.GetByOffset("offset_b / 4") + ";\n" : "let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n";
} else {
std::string common = "var outputIndices = " + c.OffsetToIndices("global_idx * 4") + ";\n"
"let offset_a0 = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b0 = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"outputIndices = " + c.OffsetToIndices("global_idx * 4 + 1") + ";\n"
"let offset_a1 = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b1 = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"outputIndices = " + c.OffsetToIndices("global_idx * 4 + 2") + ";\n"
"let offset_a2 = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b2 = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"outputIndices = " + c.OffsetToIndices("global_idx * 4 + 3") + ";\n"
"let offset_a3 = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b3 = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
get_a_data = common + "let a = vec4<input_a_value_t>(" + a.GetByOffset("offset_a0") + ", " +
a.GetByOffset("offset_a1") + ", " +
a.GetByOffset("offset_a2") + ", " +
a.GetByOffset("offset_a3") + ");\n";
get_b_data = "let b = vec4<input_b_value_t>(" + b.GetByOffset("offset_b0") + ", " +
b.GetByOffset("offset_b1") + ", " +
b.GetByOffset("offset_b2") + ", " +
b.GetByOffset("offset_b3") + ");\n";
std::string common = "var outputIndices = " + c.OffsetToIndices("global_idx * 4") +
";\n"
"let offset_a0 = " +
a.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"let offset_b0 = " +
b.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"outputIndices = " +
c.OffsetToIndices("global_idx * 4 + 1") +
";\n"
"let offset_a1 = " +
a.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"let offset_b1 = " +
b.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"outputIndices = " +
c.OffsetToIndices("global_idx * 4 + 2") +
";\n"
"let offset_a2 = " +
a.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"let offset_b2 = " +
b.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"outputIndices = " +
c.OffsetToIndices("global_idx * 4 + 3") +
";\n"
"let offset_a3 = " +
a.BroadcastedIndicesToOffset("outputIndices", c) +
";\n"
"let offset_b3 = " +
b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
get_a_data = common + "let a = vec4<input_a_value_t>(" + a.GetByOffset("offset_a0") + ", " +
a.GetByOffset("offset_a1") + ", " +
a.GetByOffset("offset_a2") + ", " +
a.GetByOffset("offset_a3") + ");\n";
get_b_data = "let b = vec4<input_b_value_t>(" + b.GetByOffset("offset_b0") + ", " +
b.GetByOffset("offset_b1") + ", " +
b.GetByOffset("offset_b2") + ", " +
b.GetByOffset("offset_b3") + ");\n";
}
}

Expand Down Expand Up @@ -105,8 +125,9 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
}

SafeInt<uint32_t> vec_size = (size + 3) / 4;
const std::string expression = output_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ?
"output_value_t(pow(vec4<f16>(a), vec4<f16>(b)))" : "output_value_t(pow(vec4<f32>(a), vec4<f32>(b)))";
const std::string expression = output_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
? "output_value_t(pow(vec4<f16>(a), vec4<f16>(b)))"
: "output_value_t(pow(vec4<f32>(a), vec4<f32>(b)))";
BinaryElementwiseProgram program{kernel_name_, kernel_name_ == "Pow" ? expression : expression_, is_broadcast, is_lhs_scalar, is_rhs_scalar, vectorize};
program
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
Expand All @@ -124,7 +145,8 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
program
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4 ? 4 : 1},
{rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4 ? 4 : 1}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, 4}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}})
.AddIndices(output_tensor->Shape())
.CacheHint(std::to_string(vectorize));
}

Expand Down Expand Up @@ -163,7 +185,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
kWebGpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", TYPE) \
.TypeConstraint("T1", TYPE1), \
.TypeConstraint("T1", TYPE1), \
KERNEL_CLASS);

#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,16 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list<ProgramOutput> output
return *this;
}

ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) {
indices_.emplace_back(shape);
return *this;
}

ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) {
indices_.emplace_back(shape);
return *this;
}

ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) {
return SetDispatchGroupSize(x, 1, 1);
}
Expand Down Expand Up @@ -309,4 +319,4 @@ ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list<ProgramO
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ class ProgramBase {
ProgramBase& AddOutput(ProgramOutput&& output);
// add multiple program outputs
ProgramBase& AddOutputs(std::initializer_list<ProgramOutput> outputs);
// add a program variable for indices
ProgramBase& AddIndices(const TensorShape& shape);
// add a program variable for indices
ProgramBase& AddIndices(TensorShape&& shape);

// set the size of dispatch groups. Y and Z are 1 if not specified.
ProgramBase& SetDispatchGroupSize(uint32_t x);
Expand Down Expand Up @@ -318,6 +322,7 @@ class ProgramBase {
inline const std::string& CacheHint() const { return cache_hint_; }
inline const std::vector<ProgramInput>& Inputs() const { return inputs_; }
inline const std::vector<ProgramOutput>& Outputs() const { return outputs_; }
inline const std::vector<TensorShape>& Indices() const { return indices_; }
inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; }
inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; }
inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; }
Expand All @@ -339,6 +344,7 @@ class ProgramBase {
std::string cache_hint_;
std::vector<ProgramInput> inputs_;
std::vector<ProgramOutput> outputs_;
std::vector<TensorShape> indices_;

uint32_t dispatch_group_size_x_;
uint32_t dispatch_group_size_y_;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/program_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ Status ProgramManager::Build(const ProgramBase& program,

ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper));

ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs());
ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices());

// code is a large std::string that contains the final shader code
std::string code;
Expand Down
30 changes: 25 additions & 5 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, Sha
return AddVariableImpl(false, name, usage, dims);
}

const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) {
const size_t indices_index = indices_vars_.size();
return *indices_vars_.emplace_back(
std::make_unique<ShaderIndicesHelper>(name,
ProgramVariableDataType::InvalidType,
use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None,
program_.Indices()[indices_index]));
}

#ifndef NDEBUG // if debug build
namespace {
// Validate if the tensor element type matches the program variable data type
Expand Down Expand Up @@ -237,13 +246,10 @@ const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input,
return *var;
}

Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
// Validate input/output as dependencies of shape_uniforms
Status ShaderHelper::ValidateShapeForInputs() const {
// Validate input as dependencies of shape_uniforms
ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(),
"Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size());
ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(),
"Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size());

for (size_t i = 0; i < input_vars_.size(); i++) {
#ifndef NDEBUG // if debug build
// Validate input shape
Expand All @@ -266,6 +272,13 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
// This will not generate any shape variables in the shader, can you can only use offset to set/get values.
}
}
return Status::OK();
}

Status ShaderHelper::ValidateShapeForOutputs() const {
// Validate output as dependencies of shape_uniforms
ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(),
"Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size());

for (size_t i = 0; i < output_vars_.size(); i++) {
#ifndef NDEBUG // if debug build
Expand All @@ -291,6 +304,13 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
return Status::OK();
}

Status ShaderHelper::ValidateIndices() const {
ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(),
"Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size());

return Status::OK();
}

Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const {
std::ostringstream ss;
ss.imbue(std::locale::classic());
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/webgpu/shader_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class ShaderHelper final {
const ShaderVariableHelper& AddOutput(const std::string& name,
ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform);

// Add an indices variable to the shader.
const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform);

// Append additional implementation code to the shader.
//
// can be called multiple times.
Expand Down Expand Up @@ -146,7 +149,9 @@ class ShaderHelper final {
Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const;
#endif

Status ValidateShapeForInputsAndOutputs() const;
Status ValidateShapeForInputs() const;
Status ValidateShapeForOutputs() const;
Status ValidateIndices() const;

// Generate source code.
//
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/shader_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ShaderIndicesHelper {
public:
ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims);

inline int Num_Components() const { return num_components_; }
inline int NumComponents() const { return num_components_; }

// create a WGSL expression ({varname}_indices_t) for getting indices from offset.
// \param offset: a WGSL expression (u32) representing the offset.
Expand Down

0 comments on commit 4d70ff9

Please sign in to comment.