From 508cb957c7f29d2bd0b0703d74cbe85fa48b420b Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 28 Oct 2024 16:21:17 +0800 Subject: [PATCH] [js/webgpu] Optimize InstanceNorm in some shapes BUG #22031 Optimize below two situations: 1. Increase workgroupSize if only one workgroup is dispatched. 2. Avoid transpose if not necessary. The overall time of demucs model becomes 106.36 ms from 154.60 ms on my dGPUs with this PR and PR #22577 --- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 859bd850862aa..a357d29667319 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -36,7 +36,10 @@ const computeChannelScaleShift = ( const f32Type = components === 1 ? 'f32' : `vec${components}f`; const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; const unitsOfWork = n * c; - + let workgroupSize = 64; + if (unitsOfWork === 1) { + workgroupSize = 256; + } const inputShape = [n, c, h / components]; const outputShape = [n, c, 2]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; @@ -49,7 +52,6 @@ const computeChannelScaleShift = ( const b = inputVariable('bias', bias.dataType, bias.dims); const output = outputVariable('output', DataType.float, 3, 2); const variables = [x, s, b, output]; - const workgroupSize = 64; return ` var workgroup_shared : array<${wgType}, ${workgroupSize}>; const workgroup_size = ${workgroupSize}u; @@ -91,7 +93,7 @@ const computeChannelScaleShift = ( { name: 'InstanceNormComputeChannelScaleShift', // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: DataType.float }], dispatchGroup: { x: unitsOfWork }, @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = ( const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // 1. transpose x from NHWC to NCHW + let needTranspose = false; const transposedXPerm = [0, xShape.length - 1]; for (let i = 0; i < xShape.length - 2; i++) { + needTranspose = needTranspose || xShape[i + 1] !== 1; transposedXPerm.push(i + 1); } - const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { - inputs: [context.inputs[0]], - outputs: [-1], - })[0]; + + needTranspose = needTranspose && xShape[xShape.length - 1] !== 1; + + const transposedX = needTranspose + ? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0] + : context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]])); // 2. compute channel scale and channel shift. const channelScaleShift = computeChannelScaleShift( context,