Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 5 additions & 31 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/conv.h"
#include "core/providers/webgpu/nn/activation_util.h"

namespace onnxruntime {
Expand Down Expand Up @@ -52,15 +53,6 @@

} // namespace

// Emits the WGSL for the OIHW->OHWI weight transpose by instantiating the
// shared shader template with this program's input/output bindings.
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Register the source tensor and the transposed output; the type aliases let
  // the template refer to their value/element types generically.
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}

Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
Expand Down Expand Up @@ -93,34 +85,16 @@
const bool has_bias = context.InputCount() > 2;
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

// Transpose OIHW Weight to OHWI
// TODO: Move to `Transpose`
// TODO: Use prepack
TensorShape weight_shape = weight->Shape();
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);

TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
OIHW2OHWIProgram transpose_program{};
transpose_program.SetWorkgroupSize(64);

const uint32_t Ci_tiles = CeilDiv(channel_input, 64u);
transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);

transpose_program.AddInput({weight,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddOutput({&ohwi_weight,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddUniformVariables({{channel_output},
{channel_input},
{kernel_height},
{kernel_width},
{Ci_tiles},
{CeilDiv(kernel_height * kernel_height, 4u)}});
ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
// Transpose OIHW Weight to OHWI
// TODO: Use prepack

Check warning on line 95 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc:95: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
Tensor ohwi_weight;
ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, weight->Shape(), &ohwi_weight, {0, 2, 3, 1}));

// im2col-matmul
const TensorShape src_shape = src->Shape();
Expand Down
16 changes: 0 additions & 16 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,6 @@
namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
//
// Tiled WebGPU transpose program for 4D convolution weights. The uniforms
// describe the input OIHW shape plus precomputed tiling factors supplied by
// the caller.
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // Uniforms:
  //   O, I, H, W - the four OIHW input dimensions.
  //   Ci_tiles   - number of 64-wide tiles covering the I (input-channel) dim.
  //   H_W_tiles  - NOTE(review): the caller computes this as
  //                CeilDiv(kernel_height * kernel_height, 4u); H*H looks like
  //                a typo for H*W — verify against the shader template.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
public:
Im2ColMatMulProgram(bool has_bias,
Expand Down
97 changes: 78 additions & 19 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/span_utils.h"
#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/tensor/transpose.h"
Expand All @@ -9,6 +10,30 @@
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace {
// Returns true iff `a` and `b` have the same length and element-wise equal
// contents. Used to detect a specific transpose permutation (e.g. {0, 2, 3, 1}).
bool AreSpansEqual(gsl::span<const size_t> a, gsl::span<const size_t> b) {
  // The four-iterator overload of std::equal performs the length check itself,
  // so no explicit size comparison is needed.
  return std::equal(a.begin(), a.end(), b.begin(), b.end());
}

// Drops all size-1 dimensions from `shape` and removes the corresponding
// entries from `adjusted_perm`, appending the equivalent lower-rank shape and
// permutation to `new_shape` / `new_perm`.
//
// Fixes vs. previous revision: the function returns nothing, so its return
// type is spelled `void` instead of `auto`, and the stray `;` after the
// closing brace (flagged by cpplint readability/braces) is removed.
void SqueezeShape(const gsl::span<const int64_t>& shape,
                  const gsl::span<const size_t>& adjusted_perm,
                  onnxruntime::TensorShapeVector& new_shape,
                  onnxruntime::TensorShapeVector& new_perm) {
  for (size_t i = 0; i < shape.size(); ++i) {
    // Keep only non-unit dimensions of the input shape.
    if (shape[i] != 1) {
      new_shape.push_back(shape[i]);
    }
    // Keep only permutation entries that select a non-unit dimension.
    if (shape[adjusted_perm[i]] != 1) {
      new_perm.push_back(adjusted_perm[i]);
    }
  }
}

Check warning on line 34 in onnxruntime/core/providers/webgpu/tensor/transpose.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/transpose.cc:34: You don't need a ; after a } [readability/braces] [4]
} // namespace

namespace onnxruntime {
namespace webgpu {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand Down Expand Up @@ -47,19 +72,14 @@
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Transpose);

auto SqueezeShape(const gsl::span<const int64_t>& shape,
const gsl::span<const size_t>& adjusted_perm,
TensorShapeVector& new_shape,
TensorShapeVector& new_perm) {
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != 1) {
new_shape.push_back(shape[i]);
}
if (shape[adjusted_perm[i]] != 1) {
new_perm.push_back(adjusted_perm[i]);
}
}
};
// Emits the WGSL for the OIHW->OHWI transpose by instantiating the shared
// shader template (now hosted under tensor/) with this program's
// input/output bindings.
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Register the source tensor and the transposed output; the type aliases let
  // the template refer to their value/element types generically.
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}

Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
Expand Down Expand Up @@ -106,12 +126,52 @@
const auto& input_shape = input.Shape();
const auto& input_dims = input_shape.GetDims();
int32_t rank = static_cast<int32_t>(input_shape.NumDimensions());

TensorShapeVector output_dims(rank);

for (int32_t i = 0; i < rank; i++) {
output_dims[i] = input_dims[permutations[i]];
}
TensorShape output_shape(output_dims);

// Check if `OIHW2OHWIProgram` can be applied.
//
// `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
// to OHWI format, utilizing workgroup tiling to maximize bandwidth through
// coalesced reads and writes. While variable names reflect this origin for
// simplicity, the shader is now generalized for broader use, supporting any
// permutation equivalent to {0, 2, 3, 1}.
//
// TODO: Extend support to 2D and 3D transpositions.

Check warning on line 144 in onnxruntime/core/providers/webgpu/tensor/transpose.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/tensor/transpose.cc:144: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
if (AreSpansEqual(permutations, AsSpan<const size_t>({0, 2, 3, 1}))) {
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(input_shape[0]);
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(input_shape[1]);
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(input_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(input_shape[3]);

// Calculate tiling for the input channel dimension (tiled by 64)
const uint32_t input_channel_tiles = CeilDiv(channel_input, 64u);
const uint32_t dispatch_size = channel_output * input_channel_tiles;

// Threshold check: Only apply if the workload is large enough to saturate
// GPU compute units. For small tensors, the overhead of the transpose
// outweighs the gain.
if (dispatch_size >= 128u) {
OIHW2OHWIProgram transpose_program{};
transpose_program.SetWorkgroupSize(64);
transpose_program.SetDispatchGroupSize(dispatch_size);
transpose_program.AddInput({&input,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddOutput({&output,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddUniformVariables({{channel_output},
{channel_input},
{kernel_height},
{kernel_width},
{input_channel_tiles},
{CeilDiv(kernel_height * kernel_width, 4u)}});
return context.RunProgram(transpose_program);
}
}

TensorShapeVector new_shape{};
TensorShapeVector new_perm{};
Expand All @@ -120,15 +180,14 @@
const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);

if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: new_shape;
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}

uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
Expand All @@ -137,13 +196,13 @@
program
.CacheHint(absl::StrJoin(permutations, "-"))
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, output_shape, 1}})
.AddUniformVariables({{output_size}});

if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
program.SetDispatchGroupSize(static_cast<uint32_t>((output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
} else {
program.SetWorkgroupSize(64u);

Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
//
// Tiled WebGPU transpose program for the permutation {0, 2, 3, 1}. Originally
// written for OIHW->OHWI convolution-weight layout conversion, which the
// uniform names still reflect; the caller in transpose.cc selects it for any
// 4D transpose equivalent to that permutation when the dispatch is large
// enough to benefit.
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // Uniforms (as supplied by the caller in transpose.cc):
  //   O, I, H, W - the four input dimensions.
  //   Ci_tiles   - CeilDiv(I, 64): number of 64-wide tiles over the I dim.
  //   H_W_tiles  - CeilDiv(H * W, 4): vec4-packed spatial extent.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

class Transpose final : public WebGpuKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
Expand Down
Loading