Skip to content

Commit

Permalink
[js/webgpu] Optimize InstanceNorm in some shapes
Browse files Browse the repository at this point in the history
BUG microsoft#22031

Optimize below two situations:
1. Increase workgroupSize if only one workgroup is dispatched.
2. Avoid transpose if not necessary.

The overall time of demucs model becomes 106.36 ms from 154.60 ms on my
dGPUs with this PR and PR microsoft#22577
  • Loading branch information
qjia7 committed Oct 29, 2024
1 parent 742594c commit 508cb95
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ const computeChannelScaleShift = (
const f32Type = components === 1 ? 'f32' : `vec${components}f`;
const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const unitsOfWork = n * c;

let workgroupSize = 64;
if (unitsOfWork === 1) {
workgroupSize = 256;
}
const inputShape = [n, c, h / components];
const outputShape = [n, c, 2];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
Expand All @@ -49,7 +52,6 @@ const computeChannelScaleShift = (
const b = inputVariable('bias', bias.dataType, bias.dims);
const output = outputVariable('output', DataType.float, 3, 2);
const variables = [x, s, b, output];
const workgroupSize = 64;
return `
var<workgroup> workgroup_shared : array<${wgType}, ${workgroupSize}>;
const workgroup_size = ${workgroupSize}u;
Expand Down Expand Up @@ -91,7 +93,7 @@ const computeChannelScaleShift = (
{
name: 'InstanceNormComputeChannelScaleShift',
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: { hint: `${components};${epsilon}`, inputDependencies },
shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: DataType.float }],
dispatchGroup: { x: unitsOfWork },
Expand Down Expand Up @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = (
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];

// 1. transpose x from NHWC to NCHW
let needTranspose = false;
const transposedXPerm = [0, xShape.length - 1];
for (let i = 0; i < xShape.length - 2; i++) {
needTranspose = needTranspose || xShape[i + 1] !== 1;
transposedXPerm.push(i + 1);
}
const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0];

needTranspose = needTranspose && xShape[xShape.length - 1] !== 1;

const transposedX = needTranspose
? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0]
: context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]]));
// 2. compute channel scale and channel shift.
const channelScaleShift = computeChannelScaleShift(
context,
Expand Down

0 comments on commit 508cb95

Please sign in to comment.