
Commit 335f67c

[js/webgpu] Optimize maybeTransposeToBNSHAndAddBias
With this optimization, 96 MultiHeadAttention|Transpose ops in Phi3 disappear, and Phi3 generation improves from 107 to 113 tokens per second on my dGPUs.
1 parent 3321735

1 file changed: +6 -0 lines changed

js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts

@@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = (
     if (input.dims.length === 3) {
       reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
     }
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
@@ -356,6 +359,9 @@
       biasOffset!,
     );
     reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
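
Why the early return is safe: this function transposes the reshaped [batchSize, sequenceLength, numHeads, headSize] tensor from BSNH to BNSH, i.e. it swaps axes 1 and 2 (the permutation behind weightTransposeAttribute here is BSNH -> BNSH, [0, 2, 1, 3]). When either swapped axis has extent 1, the row-major element order is unchanged, so the reshape alone already yields the BNSH result and the GPU transpose dispatch can be skipped. Below is a minimal standalone sketch in plain TypeScript (the helper names are illustrative, not part of the ORT codebase) that checks this claim by verifying the permutation maps every flat buffer offset to itself exactly when numHeads === 1 or sequenceLength === 1.

// Standalone sketch (hypothetical helpers, not the ORT API): a [0, 2, 1, 3]
// transpose of a row-major [batch, seq, heads, headSize] tensor is the
// identity on the underlying buffer whenever heads === 1 or seq === 1.

function flatIndex(idx: number[], dims: number[]): number {
  // Row-major linearization of a multi-dimensional index.
  let offset = 0;
  for (let i = 0; i < dims.length; i++) {
    offset = offset * dims[i] + idx[i];
  }
  return offset;
}

function transposeIsNoop(dims: number[], perm: number[]): boolean {
  const outDims = perm.map((p) => dims[p]);
  const total = dims.reduce((a, b) => a * b, 1);
  for (let flat = 0; flat < total; flat++) {
    // Decode the input flat offset into a multi-dimensional index.
    const idx: number[] = new Array(dims.length).fill(0);
    let rem = flat;
    for (let i = dims.length - 1; i >= 0; i--) {
      idx[i] = rem % dims[i];
      rem = Math.floor(rem / dims[i]);
    }
    // The transpose writes input element idx to the permuted output position.
    const outIdx = perm.map((p) => idx[p]);
    if (flatIndex(outIdx, outDims) !== flat) {
      return false; // data actually moves: a real transpose is required
    }
  }
  return true;
}

const perm = [0, 2, 1, 3]; // BSNH -> BNSH
console.log(transposeIsNoop([2, 1, 32, 96], perm)); // true: sequenceLength === 1
console.log(transposeIsNoop([2, 128, 1, 96], perm)); // true: numHeads === 1
console.log(transposeIsNoop([2, 128, 32, 96], perm)); // false: layout really changes

In the two no-op cases, returning reshapedInput directly only rewrites the tensor's dims while leaving its data untouched, removing one Transpose program dispatch per call. During token-by-token decoding the query's sequenceLength is 1, which is presumably where the 96 eliminated ops in Phi3 come from.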
