Skip to content

Commit 717ac31

Browse files
committed
[js/webgpu] Optimize Expand
Use components = 4 where possible (i.e. when the last dimension is divisible by 4). On my iGPUs, llama3.2-1B improves from 18 tokens/s to 20 tokens/s.
1 parent 742a0d3 commit 717ac31

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

js/web/lib/wasm/jsep/webgpu/ops/expand.ts

+13-5
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
4848
const shape = Array.from(inputs[1].getBigInt64Array(), Number);
4949
const outputShape: number[] = calculateOutputShape(inputShape, shape);
5050
const dataType = inputs[0].dataType;
51-
const components = dataType === DataType.bool ? 4 : 1;
51+
const isBoolOrScalar = dataType === DataType.bool || ShapeUtil.size(inputShape) === 1;
52+
const iComponents =
53+
dataType === DataType.bool ? 4 : inputShape.length > 0 && inputShape[inputShape.length - 1] % 4 === 0 ? 4 : 1;
54+
const components = isBoolOrScalar
55+
? 4
56+
: outputShape.length > 0 && outputShape[outputShape.length - 1] % 4 === 0
57+
? 4
58+
: 1;
5259
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
5360

5461
const getShaderSource = (shaderHelper: ShaderHelper) => {
55-
const input = inputVariable('input', dataType, inputShape.length, components);
62+
const input = inputVariable('input', dataType, inputShape.length, iComponents);
5663
const output = outputVariable('output', dataType, outputShape.length, components);
5764
let assignment: string;
5865
if (dataType === DataType.bool) {
@@ -74,9 +81,10 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
7481
}`;
7582
} else {
7683
assignment = `
77-
let outputIndices = ${output.offsetToIndices('global_idx')};
84+
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
7885
let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)};
79-
${output.setByOffset('global_idx', input.getByOffset('inputOffset'))}
86+
let data = ${output.type.value}(${input.getByOffset(`inputOffset / ${iComponents}`)});
87+
${output.setByOffset('global_idx', 'data')}
8088
}`;
8189
}
8290
return `
@@ -92,7 +100,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
92100
];
93101
return {
94102
name: 'Expand',
95-
shaderCache: { hint: `${outputShape.length}`, inputDependencies: ['rank'] },
103+
shaderCache: { hint: `${outputShape.length};${iComponents}${components}`, inputDependencies: ['rank'] },
96104
getShaderSource,
97105
getRunData: () => ({
98106
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],

js/web/test/data/ops/expand.jsonc

+50
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,56 @@
134134
"type": "float32"
135135
}
136136
]
137+
},
138+
{
139+
"name": "Expand in components = 1, out components = 4",
140+
"inputs": [
141+
{
142+
"data": [1, 2, 3, 4, 5, 6],
143+
"dims": [3, 2, 1],
144+
"type": "float32"
145+
},
146+
{
147+
"data": [3, 1, 8],
148+
"dims": [3],
149+
"type": "int64"
150+
}
151+
],
152+
"outputs": [
153+
{
154+
"data": [
155+
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5,
156+
5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6
157+
],
158+
"dims": [3, 2, 8],
159+
"type": "float32"
160+
}
161+
]
162+
},
163+
{
164+
"name": "Expand in components = 4, out components = 4",
165+
"inputs": [
166+
{
167+
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
168+
"dims": [1, 1, 2, 8],
169+
"type": "float32"
170+
},
171+
{
172+
"data": [2, 1, 8],
173+
"dims": [3],
174+
"type": "int64"
175+
}
176+
],
177+
"outputs": [
178+
{
179+
"data": [
180+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
181+
16
182+
],
183+
"dims": [1, 2, 2, 8],
184+
"type": "float32"
185+
}
186+
]
137187
}
138188
]
139189
},

0 commit comments

Comments
 (0)