@@ -64,23 +64,23 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
64
64
const uint32_t block_size = 32 ;
65
65
const uint32_t blocks_per_tile = tile_size / block_size;
66
66
shader.AdditionalImplementation () << " var<workgroup> sub_a: array<input_a_value_t, " << a_length_per_tile << " >;\n "
67
- << " var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX () << " >, " << WorkgroupSizeY () << " >;\n " ;
67
+ << " var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX () << " >, " << WorkgroupSizeY () << " >;\n " ;
68
68
std::string offset = " workgroup_idx * " + std::to_string (WorkgroupSizeY ());
69
69
shader.MainFunctionBody () << " let output_indices = " << y.OffsetToIndices (offset) << " ;\n "
70
70
<< " let col = output_indices[2];\n "
71
71
" let row = output_indices[1];\n "
72
72
" let batch = output_indices[0];\n "
73
73
" let n_blocks_per_col = uniforms.input_b_shape[1];\n "
74
- " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n "
74
+ << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n "
75
75
<< " // Loop over shared dimension.\n "
76
76
" for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n "
77
- " let a_col_start = tile * " << a_length_per_tile << " ;\n "
78
- " // load one tile A data into shared memory.\n "
79
- " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << " ; a_offset += " << workgroup_size << " ) {\n "
77
+ << " let a_col_start = tile * " << a_length_per_tile << " ;\n "
78
+ << " // load one tile A data into shared memory.\n "
79
+ << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << " ; a_offset += " << workgroup_size << " ) {\n "
80
80
<< " let a_col = a_col_start + a_offset;\n "
81
81
" if (a_col < uniforms.input_a_shape[2]) {\n "
82
- " sub_a[a_offset] = " << a.GetByIndices (" input_a_indices_t(batch, row, a_col)" ) << " ;\n "
83
- " } else {\n "
82
+ << " sub_a[a_offset] = " << a.GetByIndices (" input_a_indices_t(batch, row, a_col)" ) << " ;\n "
83
+ << " } else {\n "
84
84
" sub_a[a_offset] = input_a_value_t(0);\n "
85
85
" }\n "
86
86
" }\n "
@@ -97,14 +97,14 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
97
97
" let zero_point_nibble_offset: u32 = block & 0x1u;\n "
98
98
" let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n "
99
99
<< " let zero_point_word = " << zero_points.GetByOffset (" zero_point_word_index" ) << " >> zero_point_bits_offset;\n "
100
- " let zero_point = output_element_t((zero_point_word) & 0xFu);\n " ;
100
+ << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n " ;
101
101
} else {
102
102
shader.MainFunctionBody () << " // The default zero point is 8 for unsigned 4-bit quantization.\n "
103
103
" let zero_point = output_element_t(8.0);\n " ;
104
104
}
105
105
shader.MainFunctionBody () << " let scale = " << scales.GetByOffset (" b_row * n_blocks_per_col + block" ) << " ;\n "
106
- " let b_data = " << b.GetByIndices (" input_b_indices_t(b_row, block, 0)" ) << " ;\n "
107
- " var word_offset = local_id.x * " << block_size / a.NumComponents () << " ;\n "
106
+ << " let b_data = " << b.GetByIndices (" input_b_indices_t(b_row, block, 0)" ) << " ;\n "
107
+ << " var word_offset = local_id.x * " << block_size / a.NumComponents () << " ;\n "
108
108
<< " for (var i: u32 = 0; i < " << components_b_ << " ; i++) {\n " ;
109
109
switch (a.NumComponents ()) {
110
110
case 1 :
@@ -139,18 +139,18 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
139
139
}
140
140
shader.MainFunctionBody () << " )) * scale;\n "
141
141
" inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n "
142
- " word_offset += " << 8 / a.NumComponents () << " ;\n "
143
- " }\n "
142
+ << " word_offset += " << 8 / a.NumComponents () << " ;\n "
143
+ << " }\n "
144
144
" workgroupBarrier();\n "
145
145
" }\n "
146
- " if (local_idx < " << WorkgroupSizeY () << " ) {\n "
147
- " var output_value = output_value_t(0);\n "
148
- " for (var b = 0u; b < " << WorkgroupSizeX () << " ; b++) {\n "
149
- " output_value += inter_results[local_idx][b];\n "
146
+ << " if (local_idx < " << WorkgroupSizeY () << " ) {\n "
147
+ << " var output_value = output_value_t(0);\n "
148
+ << " for (var b = 0u; b < " << WorkgroupSizeX () << " ; b++) {\n "
149
+ << " output_value += inter_results[local_idx][b];\n "
150
150
" }\n "
151
151
" if (col + local_idx < uniforms.output_shape[2]) {\n "
152
- " " << y.SetByIndices (" output_indices_t(batch, row, col + local_idx)" , " output_value" ) << " ;\n "
153
- " }\n "
152
+ << " " << y.SetByIndices (" output_indices_t(batch, row, col + local_idx)" , " output_value" ) << " ;\n "
153
+ << " }\n "
154
154
" }\n " ;
155
155
} else {
156
156
const std::string quantized_data_type = QuantizedDataType (a.NumComponents ());
@@ -359,15 +359,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
359
359
if (use_block32) {
360
360
components = 1 ;
361
361
const uint32_t workgroup_size = 128 ;
362
- const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1 ;
362
+ const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
363
+ : 1 ;
363
364
const uint32_t workgroup_x = workgroup_size / workgroup_y;
364
365
program.SetWorkgroupSize (workgroup_x, workgroup_y, 1 );
365
366
program.SetDispatchGroupSize (data_size / components / workgroup_y);
366
367
} else {
367
368
program.SetDispatchGroupSize (data_size / components / output_number);
368
369
}
369
370
370
-
371
371
TensorShape reshaped_a_shape{batch_count, M, K / components_a};
372
372
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
373
373
TensorShape reshaped_y_shape{batch_count, M, N / components};
0 commit comments