Skip to content

Commit e494758

Browse files
committed
[js/webgpu] Optimize Gemm
BUG microsoft#22031 The total Gemm time in demucs model becomes 181.14 ms from over 1000 ms on my iGPUs.
1 parent 2e4e221 commit e494758

File tree

1 file changed

+160
-1
lines changed
  • js/web/lib/wasm/jsep/webgpu/ops

1 file changed

+160
-1
lines changed

js/web/lib/wasm/jsep/webgpu/ops/gemm.ts

+160-1
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,15 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
5555
if (!outputShape) {
5656
throw new Error("Can't use gemm on the given tensors");
5757
}
58+
const tileSize = 16;
59+
const numTileN = Math.ceil(N / tileSize);
60+
const numTileM = Math.ceil(M / tileSize);
61+
// TODO: Find the condition when directly use the naive one.
62+
const useShared = true;
63+
5864
const outputSize = ShapeUtil.size(outputShape);
5965
const programUniforms: ProgramUniform[] = [
60-
{ type: DataType.uint32, data: outputSize },
66+
{ type: DataType.uint32, data: useShared ? numTileN : outputSize },
6167
{ type: DataType.uint32, data: M },
6268
{ type: DataType.uint32, data: N },
6369
{ type: DataType.uint32, data: K },
@@ -130,6 +136,159 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
130136
}`;
131137
};
132138

139+
const getShaderSourceShared = (shaderHelper: ShaderHelper) => {
140+
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
141+
const b = inputVariable('b', inputs[1].dataType, inputs[1].dims);
142+
let c: IndicesHelper | null = null;
143+
const variables = [a, b];
144+
if (inputs.length === 3) {
145+
c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length);
146+
variables.push(c);
147+
}
148+
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
149+
variables.push(output);
150+
const uniforms: UniformsArrayType = [
151+
{ name: 'num_tile_n', type: 'u32' },
152+
{ name: 'M', type: 'u32' },
153+
{ name: 'N', type: 'u32' },
154+
{ name: 'K', type: 'u32' },
155+
{ name: 'alpha', type: 'f32' },
156+
{ name: 'beta', type: 'f32' },
157+
];
158+
159+
let calcResult = '';
160+
let fillWorkgroupMemory = '';
161+
if (attributes.transA && attributes.transB) {
162+
fillWorkgroupMemory = `
163+
var col = tile_row_start + local_id.x;
164+
var row = k_start + local_id.y;
165+
if (col < uniforms.M && row < uniforms.K) {
166+
tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
167+
} else {
168+
tile_a[local_id.y][local_id.x] = ${a.type.value}(0);
169+
}
170+
171+
col = k_start + local_id.x;
172+
row = tile_col_start + local_id.y;
173+
if (col < uniforms.K && row < uniforms.N) {
174+
tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
175+
} else {
176+
tile_b[local_id.y][local_id.x] = ${b.type.value}(0);
177+
}
178+
`;
179+
calcResult = `value += tile_a[k][local_id.y] * tile_b[local_id.x][k];`;
180+
} else if (attributes.transA && !attributes.transB) {
181+
fillWorkgroupMemory = `
182+
var col = tile_row_start + local_id.x;
183+
var row = k_start + local_id.y;
184+
if (col < uniforms.M && row < uniforms.K) {
185+
tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
186+
} else {
187+
tile_a[local_id.y][local_id.x] = ${a.type.value}(0);
188+
}
189+
190+
col = tile_col_start + local_id.x;
191+
row = k_start + local_id.y;
192+
if (col < uniforms.N && row < uniforms.K) {
193+
tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
194+
} else {
195+
tile_b[local_id.y][local_id.x] = ${b.type.value}(0);
196+
}
197+
`;
198+
calcResult = `value += tile_a[k][local_id.y] * tile_b[k][local_id.x];`;
199+
} else if (!attributes.transA && attributes.transB) {
200+
fillWorkgroupMemory = `
201+
var col = k_start + local_id.x;
202+
var row = tile_row_start + local_id.y;
203+
if (col < uniforms.K && row < uniforms.M) {
204+
tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
205+
} else {
206+
tile_a[local_id.y][local_id.x] = ${a.type.value}(0);
207+
}
208+
209+
col = k_start + local_id.x;
210+
row = tile_col_start + local_id.y;
211+
if (col < uniforms.K && row < uniforms.N) {
212+
tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
213+
} else {
214+
tile_b[local_id.y][local_id.x] = ${b.type.value}(0);
215+
}
216+
`;
217+
calcResult = `value += tile_a[local_id.y][k] * tile_b[local_id.x][k];`;
218+
} else if (!attributes.transA && !attributes.transB) {
219+
fillWorkgroupMemory = `
220+
var col = k_start + local_id.x;
221+
var row = tile_row_start + local_id.y;
222+
if (col < uniforms.K && row < uniforms.M) {
223+
tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
224+
} else {
225+
tile_a[local_id.y][local_id.x] = ${a.type.value}(0);
226+
}
227+
228+
col = tile_col_start + local_id.x;
229+
row = k_start + local_id.y;
230+
if (col < uniforms.N && row < uniforms.K) {
231+
tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
232+
} else {
233+
tile_b[local_id.y][local_id.x] = ${b.type.value}(0);
234+
}
235+
`;
236+
calcResult = `value += tile_a[local_id.y][k] * tile_b[k][local_id.x];`;
237+
}
238+
239+
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;';
240+
241+
return `
242+
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
243+
var<workgroup> tile_a: array<array<${a.type.storage}, ${tileSize}>, ${tileSize}>;
244+
var<workgroup> tile_b: array<array<${b.type.storage}, ${tileSize}>, ${tileSize}>;
245+
${shaderHelper.mainStart([tileSize, tileSize, 1])}
246+
let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${tileSize};
247+
let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${tileSize};
248+
let num_tiles = (uniforms.K - 1) / ${tileSize} + 1;
249+
var k_start = 0u;
250+
var value = ${output.type.value}(0);
251+
for (var t: u32 = 0u; t < num_tiles; t++) {
252+
${fillWorkgroupMemory}
253+
k_start = k_start + ${tileSize};
254+
workgroupBarrier();
255+
256+
for (var k: u32 = 0u; k < ${tileSize}; k++) {
257+
${calcResult}
258+
}
259+
workgroupBarrier();
260+
}
261+
262+
${calculateAlpha}
263+
let m = tile_row_start + local_id.y;
264+
let n = tile_col_start + local_id.x;
265+
${(() => {
266+
if (c != null) {
267+
return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${
268+
output.type.value
269+
}(uniforms.beta) * ${c.getByOffset('cOffset')};`;
270+
}
271+
return '';
272+
})()}
273+
if (m < uniforms.M && n < uniforms.N) {
274+
output[m * uniforms.N + n] = value;
275+
}
276+
}`;
277+
};
278+
279+
if (useShared) {
280+
return {
281+
name: 'GemmShared',
282+
shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies },
283+
getRunData: () => ({
284+
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
285+
dispatchGroup: { x: numTileN * numTileM },
286+
programUniforms,
287+
}),
288+
getShaderSource: getShaderSourceShared,
289+
};
290+
}
291+
133292
return {
134293
name: 'Gemm',
135294
shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies },

0 commit comments

Comments
 (0)