[JS/WebGPU] Avoid producing presentKey/presentValue outputs if pastKey/pastValue don't exist (#21782)

Avoid producing presentKey/presentValue outputs if pastKey/pastValue
don't exist.

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
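
To summarize the new rule, here is a minimal standalone sketch (simplified types and a hypothetical helper name, not the actual ONNX Runtime signatures; see the `applyAttention` diff below for the real code):

```ts
// Hedged sketch: presentKey/presentValue are assumed to exist only when
// pastKey/pastValue exist, so the requested output count is clamped.
function effectiveOutputCount(requestedOutputs: number, hasPastKey: boolean, hasPastValue: boolean): number {
  return Math.min(requestedOutputs, 1 + (hasPastKey ? 1 : 0) + (hasPastValue ? 1 : 0));
}

console.log(effectiveOutputCount(3, false, false)); // 1 -> only the attention output
console.log(effectiveOutputCount(3, true, true)); // 3 -> output + presentKey + presentValue
```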
satyajandhyala authored Aug 20, 2024
1 parent a22cc07 commit 1fb2e71
Showing 3 changed files with 118 additions and 36 deletions.
77 changes: 42 additions & 35 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -3,6 +3,7 @@

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';

import {
@@ -257,7 +258,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
};
};

const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: TensorView, n: number, d: number) => {
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
const components = getMaxComponents(d);
let WG = 64;
const dComp = d / components;
@@ -358,7 +359,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
};

const createAttentionProbsProgramInfo = (
context: ComputeContext,
outputCount: number,
q: TensorView,
key: TensorView,
pastKey: TensorView | undefined,
@@ -369,7 +370,7 @@
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
const presentKeyShape = presentKey
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
: undefined;
@@ -394,9 +395,10 @@ const createAttentionProbsProgramInfo = (
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
];

// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (pastKey) {
if (feedPastKey) {
inputDependencies.push('type');
}
if (attentionBias) {
@@ -410,7 +412,7 @@ const createAttentionProbsProgramInfo = (
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
if (pastKey) {
if (feedPastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
}
@@ -446,7 +448,7 @@ const createAttentionProbsProgramInfo = (
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
${(() => {
if (pastKey && presentKey) {
if (feedPastKey && presentKey) {
return `
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
@@ -464,7 +466,7 @@ const createAttentionProbsProgramInfo = (
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastKey && presentKey) {
if (feedPastKey && presentKey) {
return `
if (n + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
@@ -513,7 +515,7 @@ const createAttentionProbsProgramInfo = (
return {
name: 'AttentionProbs',
shaderCache: {
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
inputDependencies,
},
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
@@ -522,7 +524,7 @@ const createAttentionProbsProgramInfo = (
};

const createVxAttentionScoreProgramInfo = (
context: ComputeContext,
outputCount: number,
probs: TensorView,
v: TensorView,
pastValue: TensorView | undefined,
@@ -532,7 +534,7 @@ const createVxAttentionScoreProgramInfo = (
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
const presentValueShape = presentValue
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
: undefined;
@@ -553,7 +555,12 @@ const createVxAttentionScoreProgramInfo = (
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
];
const inputDependencies: ProgramInputTensorInfoDependency[] = pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (feedPastValue) {
inputDependencies.push('type');
}
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
if (presentValue) {
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
@@ -562,7 +569,7 @@ const createVxAttentionScoreProgramInfo = (
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
const vHelper = inputVariable('v', v.dataType, v.dims);
const inputVars = [probsHelper, vHelper];
if (pastValue) {
if (feedPastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
const output = outputVariable('output', probs.dataType, outputShape);
@@ -591,7 +598,7 @@ const createVxAttentionScoreProgramInfo = (
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
${(() => {
if (pastValue && presentValue) {
if (feedPastValue && presentValue) {
return `
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
@@ -611,7 +618,7 @@ const createVxAttentionScoreProgramInfo = (
if (n < uniforms.N && w + local_id.y < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastValue && presentValue) {
if (feedPastValue && presentValue) {
return `
if (w + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
@@ -647,7 +654,7 @@ const createVxAttentionScoreProgramInfo = (

return {
name: 'AttentionScore',
shaderCache: { hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies },
shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
getShaderSource,
};
@@ -662,26 +669,32 @@ export const applyAttention = (
_past: TensorView | undefined,
pastKey: TensorView | undefined,
pastValue: TensorView | undefined,
attentionBias: TensorView | undefined,
attentionBiasInput: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
const outputCount = context.outputCount;
// Assumption is that presentKey/presentValue exist only if pastKey/pastValue exist.
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;

const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k];
const inputsK = [q, k];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
if (attentionBias) {
inputsK.push(attentionBias);
}

// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
context,
outputCount,
q,
k,
outputCount > 1 ? pastKey : undefined,
pastKey,
attentionBias,
parameters,
attributes,
@@ -693,7 +706,6 @@ export const applyAttention = (
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
context,
probs,
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
totalSequenceLength,
@@ -702,19 +714,14 @@ export const applyAttention = (
);

// Run AttentionScore
const inputsV =
parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v];
context.compute(
createVxAttentionScoreProgramInfo(
context,
probs,
v,
outputCount > 1 && pastValue ? pastValue : undefined,
parameters,
pastSequenceLength,
),
{ inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] },
);
const inputsV = [probs, v];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
inputs: inputsV,
outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
});
};

const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
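For reference, a standalone sketch of the gating introduced above (hypothetical TensorLike type and helper names, not the ORT TensorView/ShapeUtil API): a past tensor is bound to the shader only when the corresponding present output is produced and the past tensor is non-empty.

```ts
interface TensorLike {
  dims: readonly number[];
}

// Stand-in for ShapeUtil.size(): the element count is the product of the dims (0 for an empty tensor).
const sizeOf = (dims: readonly number[]) => dims.reduce((a, b) => a * b, 1);

const shouldFeedPast = (kvNumHeads: number | undefined, outputCount: number, past?: TensorLike): boolean => {
  const producesPresent = kvNumHeads === undefined && outputCount > 1 && past !== undefined;
  return producesPresent && sizeOf(past!.dims) > 0;
};

console.log(shouldFeedPast(undefined, 3, { dims: [1, 1, 0, 1] })); // false: empty past tensor
console.log(shouldFeedPast(undefined, 3, { dims: [1, 1, 4, 8] })); // true: non-empty past tensor
console.log(shouldFeedPast(undefined, 1, { dims: [1, 1, 4, 8] })); // false: no present output requested
```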
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -18,7 +18,7 @@ import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from '
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';

const getInput = (inputs: readonly TensorView[], i: number) =>
inputs.length > i && inputs[i].dims.length > 0 && ShapeUtil.size(inputs[i].dims) > 0 ? inputs[i] : undefined;
inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined;

const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
const query = inputs[0];
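With this change, getInput keeps a zero-size tensor as long as its dims are known (for example the empty pastKey/pastValue inputs in the new test case below), and the empty-vs-non-empty decision is deferred to applyAttention. A hedged standalone illustration with a simplified type:

```ts
interface TensorViewLike {
  dims: readonly number[];
}

const size = (dims: readonly number[]) => dims.reduce((a, b) => a * b, 1);

const getInput = (inputs: readonly TensorViewLike[], i: number) =>
  inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined;

const emptyPastKey: TensorViewLike = { dims: [1, 1, 0, 1] };
const provided = getInput([emptyPastKey], 0);
console.log(provided !== undefined); // true: dims are known, so the input is forwarded
console.log(provided !== undefined && size(provided.dims) > 0); // false: downstream code skips the shader binding
```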
75 changes: 75 additions & 0 deletions js/web/test/data/ops/multihead-attention.jsonc
@@ -1073,5 +1073,80 @@
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with empty pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// AttentionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [],
"dims": [1, 1, 0, 1],
"type": "float32"
},
// PastValue
{
"data": [],
"dims": [1, 1, 0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
{
"data": [2],
"dims": [1, 1, 1, 1],
"type": "float32"
},
{
"data": [3],
"dims": [1, 1, 1, 1],
"type": "float32"
}
]
}
]
}
]
