Skip to content

Commit 89201c3

Browse files
committed
clean code
1 parent ed571b6 commit 89201c3

File tree

1 file changed

+55
-66
lines changed

1 file changed

+55
-66
lines changed

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

+55-66
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ export const createMatMulNBitsProgramInfo = (
266266
};
267267
};
268268

269-
// TODO: support zeroPoints as input
270269
// Currently, only blockSize = 32 is supported.
271270
export const createMatMulNBitsBlockSize32ProgramInfo = (
272271
inputs: readonly TensorView[],
@@ -284,16 +283,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
284283
const dataType = inputs[0].dataType;
285284
const aComponents = getMaxComponents(attributes.k);
286285
const bComponents = getMaxComponents(blobSizeInWords);
287-
// const components = getMaxComponents(dimBOuter);
288-
const components = 1;
289286
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
290287

291288
const workgroupSize = 128;
292-
const workgroupY = 8;
289+
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
293290
const workgroupX = workgroupSize / workgroupY;
294291
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
295292
const aLengthPerTile = tileSize / aComponents;
296-
const blocksPerTile = tileSize / attributes.blockSize; // This requires tileSize must be larger than or equal to blockSize.
293+
const blocksPerTile = tileSize / attributes.blockSize;
294+
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
297295

298296
const programUniforms: ProgramUniform[] = [];
299297
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
@@ -302,7 +300,10 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
302300
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
303301
programUniforms.push(...createTensorShapeVariables(bShape));
304302
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
305-
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
303+
if (inputs.length === 4) {
304+
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
305+
}
306+
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
306307
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
307308

308309
const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -311,10 +312,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
311312
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
312313
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
313314
const inputVariables = [a, b, scales];
315+
const zeroPoints =
316+
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
317+
if (zeroPoints) {
318+
inputVariables.push(zeroPoints);
319+
}
314320
const outputRank = outputShapeTemp.length;
315-
const output = outputVariable('output', inputs[0].dataType, outputRank, components);
321+
const output = outputVariable('output', inputs[0].dataType, outputRank);
316322
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
317-
const readA = (() => {
323+
const readA = () => {
318324
switch (aComponents) {
319325
case 1:
320326
return `
@@ -331,66 +337,19 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
331337
default:
332338
throw new Error(`${aComponents}-component is not supported.`);
333339
}
334-
});
335-
336-
const processOneWord = (): string => {
337-
let calcStr = readA();
338-
for (let c = 0; c < components; c++) {
339-
calcStr += `
340-
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
341-
b_value_lower = unpack4xU8(b_value & b_mask);
342-
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
343-
b_quantized_values = mat2x4<${dataType}>(${Array.from(
344-
{ length: 4 },
345-
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
346-
).join(', ')});
347-
b_dequantized_values = ${(() => {
348-
return `(b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale${c};`;
349-
})()};
350-
inter_results[local_id.y][local_id.x] += ${Array.from(
351-
{ length: 2 },
352-
(_, i) =>
353-
`${
354-
`dot(a_data${i}, b_dequantized_values[${i}])`
355-
}`,
356-
).join(' + ')};
357-
`;
358-
}
359-
return calcStr;
360340
};
361341

362-
const prepareScaleAndBData = (): string => {
363-
let calcStr = `var col_index = col * ${components};`;
364-
for (let c = 0; c < components; c++) {
365-
calcStr += `
366-
let b_row = workgroup_id.x * ${workgroupY} + local_id.y;
367-
let block = tile * ${blocksPerTile} + local_id.x;
368-
let scale${c} = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
369-
let b${c}_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
370-
col_index += 1;`;
371-
}
372-
calcStr += `
373-
var b_value: u32;
374-
let b_mask: u32 = 0x0F0F0F0Fu;
375-
var b_value_lower: vec4<u32>;
376-
var b_value_upper: vec4<u32>;
377-
var b_quantized_values:mat2x4<${dataType}>;
378-
var b_dequantized_values: mat2x4<${dataType}>;`;
379-
return calcStr;
380-
};
381342
return `
382343
var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
383344
var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
384345
${shaderHelper.declareVariables(...inputVariables, output)}
385346
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
386-
let col = workgroup_id.x * ${workgroupY} + local_idx;
387-
let row = workgroup_id.y;
388-
let batch = workgroup_id.z;
347+
let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)};
348+
let col = output_indices[2];
349+
let row = output_indices[1];
350+
let batch = output_indices[0];
389351
let n_blocks_per_col = uniforms.b_shape[1];
390352
let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;
391-
let blob_size_in_words = uniforms.b_shape[2];
392-
// The default zero point is 8 for unsigned 4-bit quantization.
393-
let zero_point = ${dataType}(${8.0});
394353

395354
// Loop over shared dimension.
396355
for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
@@ -409,10 +368,40 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
409368
workgroupBarrier();
410369

411370
// each thread process one block
371+
let b_row = col + local_id.y;
372+
let block = tile * ${blocksPerTile} + local_id.x;
373+
${
374+
zeroPoints
375+
? `
376+
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
377+
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
378+
let zero_point_word_index = zero_point_byte_count >> 0x2u;
379+
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
380+
let zero_point_nibble_offset: u32 = block & 0x1u;
381+
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
382+
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
383+
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
384+
: `
385+
// The default zero point is 8 for unsigned 4-bit quantization.
386+
let zero_point = ${dataType}(${8.0});`
387+
}
388+
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
389+
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
412390
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
413-
${prepareScaleAndBData()}
414391
for (var i: u32 = 0; i < ${bComponents}; i++) {
415-
${processOneWord()}
392+
${readA()}
393+
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
394+
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
395+
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
396+
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
397+
{ length: 4 },
398+
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
399+
).join(', ')});
400+
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
401+
inter_results[local_id.y][local_id.x] += ${Array.from(
402+
{ length: 2 },
403+
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
404+
).join(' + ')};
416405
word_offset += ${8 / aComponents};
417406
}
418407
workgroupBarrier();
@@ -423,22 +412,22 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
423412
for (var b = 0u; b < ${workgroupX}; b++) {
424413
output_value += inter_results[local_idx][b];
425414
}
426-
if (col < uniforms.output_shape[2])
415+
if (col + local_idx < uniforms.output_shape[2])
427416
{
428-
${output.setByIndices(`${output.type.indices}(batch, row, col)`, 'output_value')}
417+
${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')}
429418
}
430419
}
431420
}`;
432421
};
433422
return {
434423
name: 'BlockwiseMatMulNBits32',
435424
shaderCache: {
436-
hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components}`,
425+
hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY}`,
437426
inputDependencies: Array(inputs.length).fill('rank'),
438427
},
439428
getRunData: () => ({
440429
outputs: [{ dims: outputShape, dataType }],
441-
dispatchGroup: { x: Math.ceil(dimBOuter / components / workgroupY), y: dimAOuter, z: batchSize },
430+
dispatchGroup: { x: dispatchSize },
442431
programUniforms,
443432
}),
444433
getShaderSource,
@@ -447,7 +436,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
447436

448437
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
449438
validateInputs(context.inputs, attributes);
450-
if(context.inputs.length < 4 && attributes.blockSize == 32 && context.adapterInfo.isVendor("intel")) {
439+
if (attributes.blockSize === 32 && context.adapterInfo.isVendor('intel')) {
451440
context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
452441
} else {
453442
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));

0 commit comments

Comments
 (0)