diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
index cfea39e1464d3..02969ba61e4ab 100644
--- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
+++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
@@ -6,6 +6,7 @@
 #include "core/providers/webgpu/webgpu_utils.h"
 #include "core/providers/webgpu/nn/im2col_matmul.h"
+#include "core/providers/webgpu/nn/conv.h"
 #include "core/providers/webgpu/nn/activation_util.h"
 
 namespace onnxruntime {
@@ -52,15 +53,6 @@ bool IsDeviceSupported(const ComputeContextBase& context) {
 
 }  // namespace
 
-Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-
-  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
-                             WGSL_TEMPLATE_VARIABLE(output, output),
-                             WGSL_TEMPLATE_VARIABLE(src, src));
-}
-
 Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -93,34 +85,16 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
   const bool has_bias = context.InputCount() > 2;
   const auto* bias = has_bias ? context.Input(2) : nullptr;
 
-  // Transpose OIHW Weight to OHWI
-  // TODO: Move to `Transpose`
-  // TODO: Use prepack
   TensorShape weight_shape = weight->Shape();
   const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
   const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
   const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
   const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
-  TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
-  Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
-  OIHW2OHWIProgram transpose_program{};
-  transpose_program.SetWorkgroupSize(64);
-
-  const uint32_t Ci_tiles = CeilDiv(channel_input, 64u);
-  transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);
-
-  transpose_program.AddInput({weight,
-                              ProgramTensorMetadataDependency::TypeAndRank});
-  transpose_program.AddOutput({&ohwi_weight,
-                               ProgramTensorMetadataDependency::TypeAndRank});
-  transpose_program.AddUniformVariables({{channel_output},
-                                         {channel_input},
-                                         {kernel_height},
-                                         {kernel_width},
-                                         {Ci_tiles},
-                                         {CeilDiv(kernel_height * kernel_height, 4u)}});
-  ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
+  // Transpose OIHW Weight to OHWI
+  // TODO: Use prepack
+  Tensor ohwi_weight;
+  ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, weight->Shape(), &ohwi_weight, {0, 2, 3, 1}));
 
   // im2col-matmul
   const TensorShape src_shape = src->Shape();
diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
index ed24100879520..c365cfda0c43b 100644
--- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
+++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
@@ -18,22 +18,6 @@
 namespace onnxruntime {
 namespace webgpu {
 
-// Transpose OIHW Weight to OHWI
-class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
- public:
-  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
-
-  Status GenerateShaderCode(ShaderHelper& shader) const override;
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
-      {"O", ProgramUniformVariableDataType::Uint32},
-      {"I", ProgramUniformVariableDataType::Uint32},
-      {"H", ProgramUniformVariableDataType::Uint32},
-      {"W", ProgramUniformVariableDataType::Uint32},
-      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
-      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
-};
-
 class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
  public:
   Im2ColMatMulProgram(bool has_bias,
diff --git a/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template b/onnxruntime/core/providers/webgpu/tensor/oihw_to_ohwi.wgsl.template
similarity index 100%
rename from onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template
rename to onnxruntime/core/providers/webgpu/tensor/oihw_to_ohwi.wgsl.template
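For reference, the OIHW→OHWI relayout that the moved shader (and the `TransposeKernel` call above) implements is equivalent to the following naive CPU loop. This is a minimal standalone sketch for illustration, not code from this patch:

```cpp
#include <cstddef>

// Reference semantics of the OIHW -> OHWI weight relayout: the output is
// indexed as [o][h][w][i] while the input is indexed as [o][i][h][w].
// Illustrative only; the GPU shader additionally tiles I by 64 per workgroup
// so that reads and writes stay coalesced.
void TransposeOIHW2OHWI(const float* src, float* dst,
                        size_t O, size_t I, size_t H, size_t W) {
  for (size_t o = 0; o < O; ++o)
    for (size_t i = 0; i < I; ++i)
      for (size_t h = 0; h < H; ++h)
        for (size_t w = 0; w < W; ++w)
          dst[((o * H + h) * W + w) * I + i] = src[((o * I + i) * H + h) * W + w];
}
```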
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
index 7b1c1d8888a19..230d172d7404e 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/common/span_utils.h"
 #include "core/common/inlined_containers.h"
 #include "core/providers/cpu/tensor/utils.h"
 #include "core/providers/webgpu/tensor/transpose.h"
@@ -9,6 +10,30 @@
 #include "core/providers/webgpu/webgpu_supported_types.h"
 #include "core/providers/webgpu/webgpu_utils.h"
 
+namespace {
+bool AreSpansEqual(gsl::span<const size_t> a, gsl::span<const size_t> b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+
+  return std::equal(a.begin(), a.end(), b.begin());
+}
+
+auto SqueezeShape(const gsl::span<const int64_t>& shape,
+                  const gsl::span<const size_t>& adjusted_perm,
+                  onnxruntime::TensorShapeVector& new_shape,
+                  onnxruntime::TensorShapeVector& new_perm) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] != 1) {
+      new_shape.push_back(shape[i]);
+    }
+    if (shape[adjusted_perm[i]] != 1) {
+      new_perm.push_back(adjusted_perm[i]);
+    }
+  }
+};
+}  // namespace
+
 namespace onnxruntime {
 namespace webgpu {
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -47,19 +72,14 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("T", WebGpuSupportedNumberTypes()),
     Transpose);
 
-auto SqueezeShape(const gsl::span<const int64_t>& shape,
-                  const gsl::span<const size_t>& adjusted_perm,
-                  TensorShapeVector& new_shape,
-                  TensorShapeVector& new_perm) {
-  for (size_t i = 0; i < shape.size(); ++i) {
-    if (shape[i] != 1) {
-      new_shape.push_back(shape[i]);
-    }
-    if (shape[adjusted_perm[i]] != 1) {
-      new_perm.push_back(adjusted_perm[i]);
-    }
-  }
-};
+Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+
+  return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
+                             WGSL_TEMPLATE_VARIABLE(output, output),
+                             WGSL_TEMPLATE_VARIABLE(src, src));
+}
 
 Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
@@ -106,12 +126,52 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
   const auto& input_shape = input.Shape();
   const auto& input_dims = input_shape.GetDims();
   int32_t rank = static_cast<int32_t>(input_shape.NumDimensions());
-
   TensorShapeVector output_dims(rank);
   for (int32_t i = 0; i < rank; i++) {
     output_dims[i] = input_dims[permutations[i]];
   }
+  TensorShape output_shape(output_dims);
+
+  // Check if `OIHW2OHWIProgram` can be applied.
+  //
+  // `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
+  // to OHWI format, utilizing workgroup tiling to maximize bandwidth through
+  // coalesced reads and writes. While variable names reflect this origin for
+  // simplicity, the shader is now generalized for broader use, supporting any
+  // permutation equivalent to {0, 2, 3, 1}.
+  //
+  // TODO: Extend support to 2D and 3D transpositions.
+  if (AreSpansEqual(permutations, AsSpan<size_t>({0, 2, 3, 1}))) {
+    const uint32_t channel_output = onnxruntime::narrow<uint32_t>(input_shape[0]);
+    const uint32_t channel_input = onnxruntime::narrow<uint32_t>(input_shape[1]);
+    const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(input_shape[2]);
+    const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(input_shape[3]);
+
+    // Calculate tiling for the input channel dimension (tiled by 64)
+    const uint32_t input_channel_tiles = CeilDiv(channel_input, 64u);
+    const uint32_t dispatch_size = channel_output * input_channel_tiles;
+
+    // Threshold check: Only apply if the workload is large enough to saturate
+    // GPU compute units. For small tensors, the overhead of the transpose
+    // outweighs the gain.
+    if (dispatch_size >= 128u) {
+      OIHW2OHWIProgram transpose_program{};
+      transpose_program.SetWorkgroupSize(64);
+      transpose_program.SetDispatchGroupSize(dispatch_size);
+      transpose_program.AddInput({&input,
+                                  ProgramTensorMetadataDependency::TypeAndRank});
+      transpose_program.AddOutput({&output,
+                                   ProgramTensorMetadataDependency::TypeAndRank});
+      transpose_program.AddUniformVariables({{channel_output},
+                                             {channel_input},
+                                             {kernel_height},
+                                             {kernel_width},
+                                             {input_channel_tiles},
+                                             {CeilDiv(kernel_height * kernel_width, 4u)}});
+      return context.RunProgram(transpose_program);
+    }
+  }
 
   TensorShapeVector new_shape{};
   TensorShapeVector new_perm{};
@@ -120,7 +180,6 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
   const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
   const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
   auto new_input_shape = input_shape;
-  TensorShape new_output_shape(output_dims);
 
   if (use_shared) {
     new_input_shape = channels_last
@@ -128,7 +187,7 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
                           : channels_first
                               ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
                               : new_shape;
-    new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
+    output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
   }
 
   uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
@@ -137,13 +196,13 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
   program
       .CacheHint(absl::StrJoin(permutations, "-"))
       .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
-      .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
+      .AddOutputs({{&output, ProgramTensorMetadataDependency::None, output_shape, 1}})
       .AddUniformVariables({{output_size}});
 
   if (use_shared) {
     program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
-    program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
-                                 static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
+    program.SetDispatchGroupSize(static_cast<uint32_t>((output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
+                                 static_cast<uint32_t>(((output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
   } else {
     program.SetWorkgroupSize(64u);
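The `SqueezeShape` helper that moves into the anonymous namespace above drops size-1 dimensions from the shape (and the corresponding entries from the permutation) before a transpose strategy is chosen. Below is a standalone sketch of its behavior, with plain `std::vector` standing in for `TensorShapeVector`; it is illustrative only, not code from this patch:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone copy of the SqueezeShape logic from the diff.
void SqueezeShape(const std::vector<int64_t>& shape,
                  const std::vector<size_t>& perm,
                  std::vector<int64_t>& new_shape,
                  std::vector<size_t>& new_perm) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] != 1) new_shape.push_back(shape[i]);
    if (shape[perm[i]] != 1) new_perm.push_back(perm[i]);
  }
}

int main() {
  // A leading batch dim of 1 is squeezed away, so the 4D permutation
  // {0, 2, 3, 1} is reduced to the three surviving entries {2, 3, 1}.
  std::vector<int64_t> new_shape;
  std::vector<size_t> new_perm;
  SqueezeShape({1, 2, 3, 4}, {0, 2, 3, 1}, new_shape, new_perm);
  for (auto d : new_shape) std::cout << d << ' ';  // prints: 2 3 4
  std::cout << '\n';
  for (auto p : new_perm) std::cout << p << ' ';   // prints: 2 3 1
}
```

Note that the surviving permutation entries still index the original, unsqueezed shape; the caller only compares them against fixed patterns such as `{3, 1, 2}` to pick the shared-memory path.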
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h
index 5e9ccc6750cd6..abd3b2bb79e47 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.h
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -11,6 +11,22 @@
 namespace onnxruntime {
 namespace webgpu {
 
+// Transpose OIHW Weight to OHWI
+class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
+ public:
+  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
+
+  Status GenerateShaderCode(ShaderHelper& shader) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"O", ProgramUniformVariableDataType::Uint32},
+      {"I", ProgramUniformVariableDataType::Uint32},
+      {"H", ProgramUniformVariableDataType::Uint32},
+      {"W", ProgramUniformVariableDataType::Uint32},
+      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
+      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
+};
+
 class Transpose final : public WebGpuKernel, public TransposeBase {
  public:
   Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
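The generalized fast path in `Transpose::DoTranspose` is gated on both the permutation pattern and a dispatch-size threshold. The following self-contained model mirrors that eligibility check from the diff; the function name `UsesOIHW2OHWIFastPath` is hypothetical and for illustration only:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Mirrors the CeilDiv helper used in the diff.
constexpr uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Returns true when the tiled OIHW2OHWI shader would be dispatched instead of
// the generic transpose program.
bool UsesOIHW2OHWIFastPath(const std::array<size_t, 4>& perm,
                           uint32_t channel_output, uint32_t channel_input) {
  // The shader only handles permutations equivalent to {0, 2, 3, 1}.
  if (perm != std::array<size_t, 4>{0, 2, 3, 1}) return false;
  // One workgroup per (output channel, 64-wide input-channel tile) pair.
  const uint32_t dispatch_size = channel_output * CeilDiv(channel_input, 64u);
  // Small dispatches cannot saturate the GPU, so fall back to the generic path.
  return dispatch_size >= 128u;
}

int main() {
  // O=64, I=128: 64 * ceil(128/64) = 128 workgroups -> fast path taken.
  std::cout << UsesOIHW2OHWIFastPath({0, 2, 3, 1}, 64, 128) << '\n';  // 1
  // O=4, I=8: 4 * 1 = 4 workgroups -> generic transpose instead.
  std::cout << UsesOIHW2OHWIFastPath({0, 2, 3, 1}, 4, 8) << '\n';     // 0
}
```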