@@ -266,7 +266,6 @@ export const createMatMulNBitsProgramInfo = (
266
266
};
267
267
};
268
268
269
- // TODO: support zeroPoints as input
270
269
// Currently, only blockSize = 32 is supported.
271
270
export const createMatMulNBitsBlockSize32ProgramInfo = (
272
271
inputs: readonly TensorView[],
@@ -284,16 +283,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
284
283
const dataType = inputs[0].dataType;
285
284
const aComponents = getMaxComponents(attributes.k);
286
285
const bComponents = getMaxComponents(blobSizeInWords);
287
- // const components = getMaxComponents(dimBOuter);
288
- const components = 1;
289
286
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
290
287
291
288
const workgroupSize = 128;
292
- const workgroupY = 8 ;
289
+ const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1 ;
293
290
const workgroupX = workgroupSize / workgroupY;
294
291
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
295
292
const aLengthPerTile = tileSize / aComponents;
296
- const blocksPerTile = tileSize / attributes.blockSize; // This requires tileSize must be larger than or equal to blockSize.
293
+ const blocksPerTile = tileSize / attributes.blockSize;
294
+ const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
297
295
298
296
const programUniforms: ProgramUniform[] = [];
299
297
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
@@ -302,7 +300,10 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
302
300
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
303
301
programUniforms.push(...createTensorShapeVariables(bShape));
304
302
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
305
- const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
303
+ if (inputs.length === 4) {
304
+ programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
305
+ }
306
+ const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
306
307
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
307
308
308
309
const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -311,10 +312,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
311
312
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
312
313
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
313
314
const inputVariables = [a, b, scales];
315
+ const zeroPoints =
316
+ inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
317
+ if (zeroPoints) {
318
+ inputVariables.push(zeroPoints);
319
+ }
314
320
const outputRank = outputShapeTemp.length;
315
- const output = outputVariable('output', inputs[0].dataType, outputRank, components );
321
+ const output = outputVariable('output', inputs[0].dataType, outputRank);
316
322
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
317
- const readA = (( ) => {
323
+ const readA = () => {
318
324
switch (aComponents) {
319
325
case 1:
320
326
return `
@@ -331,66 +337,19 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
331
337
default:
332
338
throw new Error(`${aComponents}-component is not supported.`);
333
339
}
334
- });
335
-
336
- const processOneWord = (): string => {
337
- let calcStr = readA();
338
- for (let c = 0; c < components; c++) {
339
- calcStr += `
340
- b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
341
- b_value_lower = unpack4xU8(b_value & b_mask);
342
- b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
343
- b_quantized_values = mat2x4<${dataType}>(${Array.from(
344
- { length: 4 },
345
- (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
346
- ).join(', ')});
347
- b_dequantized_values = ${(() => {
348
- return `(b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale${c};`;
349
- })()};
350
- inter_results[local_id.y][local_id.x] += ${Array.from(
351
- { length: 2 },
352
- (_, i) =>
353
- `${
354
- `dot(a_data${i}, b_dequantized_values[${i}])`
355
- }`,
356
- ).join(' + ')};
357
- `;
358
- }
359
- return calcStr;
360
340
};
361
341
362
- const prepareScaleAndBData = (): string => {
363
- let calcStr = `var col_index = col * ${components};`;
364
- for (let c = 0; c < components; c++) {
365
- calcStr += `
366
- let b_row = workgroup_id.x * ${workgroupY} + local_id.y;
367
- let block = tile * ${blocksPerTile} + local_id.x;
368
- let scale${c} = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
369
- let b${c}_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
370
- col_index += 1;`;
371
- }
372
- calcStr += `
373
- var b_value: u32;
374
- let b_mask: u32 = 0x0F0F0F0Fu;
375
- var b_value_lower: vec4<u32>;
376
- var b_value_upper: vec4<u32>;
377
- var b_quantized_values:mat2x4<${dataType}>;
378
- var b_dequantized_values: mat2x4<${dataType}>;`;
379
- return calcStr;
380
- };
381
342
return `
382
343
var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
383
344
var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
384
345
${shaderHelper.declareVariables(...inputVariables, output)}
385
346
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
386
- let col = workgroup_id.x * ${workgroupY} + local_idx;
387
- let row = workgroup_id.y;
388
- let batch = workgroup_id.z;
347
+ let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)};
348
+ let col = output_indices[2];
349
+ let row = output_indices[1];
350
+ let batch = output_indices[0];
389
351
let n_blocks_per_col = uniforms.b_shape[1];
390
352
let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;
391
- let blob_size_in_words = uniforms.b_shape[2];
392
- // The default zero point is 8 for unsigned 4-bit quantization.
393
- let zero_point = ${dataType}(${8.0});
394
353
395
354
// Loop over shared dimension.
396
355
for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
@@ -409,10 +368,40 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
409
368
workgroupBarrier();
410
369
411
370
// each thread process one block
371
+ let b_row = col + local_id.y;
372
+ let block = tile * ${blocksPerTile} + local_id.x;
373
+ ${
374
+ zeroPoints
375
+ ? `
376
+ let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
377
+ let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
378
+ let zero_point_word_index = zero_point_byte_count >> 0x2u;
379
+ let zero_point_byte_offset = zero_point_byte_count & 0x3u;
380
+ let zero_point_nibble_offset: u32 = block & 0x1u;
381
+ let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
382
+ let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
383
+ let zero_point = ${dataType}((zero_point_word) & 0xFu);`
384
+ : `
385
+ // The default zero point is 8 for unsigned 4-bit quantization.
386
+ let zero_point = ${dataType}(${8.0});`
387
+ }
388
+ let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
389
+ let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
412
390
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
413
- ${prepareScaleAndBData()}
414
391
for (var i: u32 = 0; i < ${bComponents}; i++) {
415
- ${processOneWord()}
392
+ ${readA()}
393
+ let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
394
+ let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
395
+ let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
396
+ let b_quantized_values = mat2x4<${dataType}>(${Array.from(
397
+ { length: 4 },
398
+ (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
399
+ ).join(', ')});
400
+ let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
401
+ inter_results[local_id.y][local_id.x] += ${Array.from(
402
+ { length: 2 },
403
+ (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
404
+ ).join(' + ')};
416
405
word_offset += ${8 / aComponents};
417
406
}
418
407
workgroupBarrier();
@@ -423,22 +412,22 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
423
412
for (var b = 0u; b < ${workgroupX}; b++) {
424
413
output_value += inter_results[local_idx][b];
425
414
}
426
- if (col < uniforms.output_shape[2])
415
+ if (col + local_idx < uniforms.output_shape[2])
427
416
{
428
- ${output.setByIndices(`${output.type.indices}(batch, row, col)`, 'output_value')}
417
+ ${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx )`, 'output_value')}
429
418
}
430
419
}
431
420
}`;
432
421
};
433
422
return {
434
423
name: 'BlockwiseMatMulNBits32',
435
424
shaderCache: {
436
- hint: `${attributes.blockSize};${attributes.bits };${aComponents };${bComponents };${components }`,
425
+ hint: `${attributes.blockSize};${aComponents };${bComponents };${workgroupX };${workgroupY }`,
437
426
inputDependencies: Array(inputs.length).fill('rank'),
438
427
},
439
428
getRunData: () => ({
440
429
outputs: [{ dims: outputShape, dataType }],
441
- dispatchGroup: { x: Math.ceil(dimBOuter / components / workgroupY), y: dimAOuter, z: batchSize },
430
+ dispatchGroup: { x: dispatchSize },
442
431
programUniforms,
443
432
}),
444
433
getShaderSource,
@@ -447,7 +436,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
447
436
448
437
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
449
438
validateInputs(context.inputs, attributes);
450
- if(context.inputs.length < 4 && attributes.blockSize == 32 && context.adapterInfo.isVendor(" intel" )) {
439
+ if ( attributes.blockSize === 32 && context.adapterInfo.isVendor(' intel' )) {
451
440
context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
452
441
} else {
453
442
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
0 commit comments