Skip to content

Commit 27043c9

Browse files
committed
format
1 parent 5a0a106 commit 27043c9

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

+20-20
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,23 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
6464
const uint32_t block_size = 32;
6565
const uint32_t blocks_per_tile = tile_size / block_size;
6666
shader.AdditionalImplementation() << "var<workgroup> sub_a: array<input_a_value_t, " << a_length_per_tile << ">;\n"
67-
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, "<< WorkgroupSizeY() << ">;\n";
67+
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">;\n";
6868
std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY());
6969
shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n"
7070
<< " let col = output_indices[2];\n"
7171
" let row = output_indices[1];\n"
7272
" let batch = output_indices[0];\n"
7373
" let n_blocks_per_col = uniforms.input_b_shape[1];\n"
74-
" let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
74+
<< " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
7575
<< " // Loop over shared dimension.\n"
7676
" for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n"
77-
" let a_col_start = tile * " << a_length_per_tile << ";\n"
78-
" // load one tile A data into shared memory.\n"
79-
" for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
77+
<< " let a_col_start = tile * " << a_length_per_tile << ";\n"
78+
<< " // load one tile A data into shared memory.\n"
79+
<< " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
8080
<< " let a_col = a_col_start + a_offset;\n"
8181
" if (a_col < uniforms.input_a_shape[2]) {\n"
82-
" sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n"
83-
" } else {\n"
82+
<< " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n"
83+
<< " } else {\n"
8484
" sub_a[a_offset] = input_a_value_t(0);\n"
8585
" }\n"
8686
" }\n"
@@ -97,14 +97,14 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
9797
" let zero_point_nibble_offset: u32 = block & 0x1u;\n"
9898
" let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n"
9999
<< " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n"
100-
" let zero_point = output_element_t((zero_point_word) & 0xFu);\n";
100+
<< " let zero_point = output_element_t((zero_point_word) & 0xFu);\n";
101101
} else {
102102
shader.MainFunctionBody() << " // The default zero point is 8 for unsigned 4-bit quantization.\n"
103103
" let zero_point = output_element_t(8.0);\n";
104104
}
105105
shader.MainFunctionBody() << " let scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
106-
" let b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
107-
" var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
106+
<< " let b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
107+
<< " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
108108
<< " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n";
109109
switch (a.NumComponents()) {
110110
case 1:
@@ -139,18 +139,18 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
139139
}
140140
shader.MainFunctionBody() << ")) * scale;\n"
141141
" inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n"
142-
" word_offset += " << 8 / a.NumComponents() << ";\n"
143-
" }\n"
142+
<< " word_offset += " << 8 / a.NumComponents() << ";\n"
143+
<< " }\n"
144144
" workgroupBarrier();\n"
145145
" }\n"
146-
" if (local_idx < " << WorkgroupSizeY() << ") {\n"
147-
" var output_value = output_value_t(0);\n"
148-
" for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
149-
" output_value += inter_results[local_idx][b];\n"
146+
<< " if (local_idx < " << WorkgroupSizeY() << ") {\n"
147+
<< " var output_value = output_value_t(0);\n"
148+
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
149+
<< " output_value += inter_results[local_idx][b];\n"
150150
" }\n"
151151
" if (col + local_idx < uniforms.output_shape[2]) {\n"
152-
" " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
153-
" }\n"
152+
<< " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
153+
<< " }\n"
154154
" }\n";
155155
} else {
156156
const std::string quantized_data_type = QuantizedDataType(a.NumComponents());
@@ -359,15 +359,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
359359
if (use_block32) {
360360
components = 1;
361361
const uint32_t workgroup_size = 128;
362-
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
362+
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
363+
: 1;
363364
const uint32_t workgroup_x = workgroup_size / workgroup_y;
364365
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
365366
program.SetDispatchGroupSize(data_size / components / workgroup_y);
366367
} else {
367368
program.SetDispatchGroupSize(data_size / components / output_number);
368369
}
369370

370-
371371
TensorShape reshaped_a_shape{batch_count, M, K / components_a};
372372
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
373373
TensorShape reshaped_y_shape{batch_count, M, N / components};

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ using namespace onnxruntime::webgpu;
1515
class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
1616
public:
1717
MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"},
18-
output_number_{output_number},
19-
components_b_{components_b},
20-
has_zero_points_{has_zero_points},
18+
output_number_{output_number},
19+
components_b_{components_b},
20+
has_zero_points_{has_zero_points},
2121
use_block32_{use_block32} {
2222
}
2323

0 commit comments

Comments
 (0)