[js/webgpu] support FlashAttention-2 for attention operator
xhcao committed Nov 22, 2024
1 parent e430795 commit c936365
Showing 1 changed file with 179 additions and 12 deletions.
191 changes: 179 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -1004,21 +1004,188 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
  );
};

// The algorithm follows the FlashAttention-2 forward pass described in https://tridao.me/publications/flash2/flash2.pdf
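// A sketch of the online-softmax recurrence the kernel below implements (notation loosely
// follows the paper): for each key/value column block j, a workgroup keeps a running row
// max m, a running row sum l and an unnormalized output accumulator O for its block of
// query rows, updated as
//   m_new = max(m_old, rowmax(S_j))
//   l_new = exp(m_old - m_new) * l_old + rowsum(exp(S_j - m_new))
//   O_new = exp(m_old - m_new) * O_old + exp(S_j - m_new) * V_j
// where S_j = alpha * Q_i * K_j^T; O is divided by l only once, after the last column block.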
const createFlashAttentionV2ProgramInfo = (
  q: TensorView,
  k: TensorView,
  v: TensorView,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
) => {
  const components = 4;
  // Block sizes Br x Bc of the score tile S_i_j = alpha * Q_i * K_j^T processed by one workgroup.
  const bR = 32;
  const bC = bR;
  // Number of row (query) and column (key/value) blocks along the sequence dimension.
  const tR = q.dims[2] / bR;
  const tC = k.dims[2] / bC;
  // Head size in vec4 components, and the number of 32-element (8 x vec4) tiles along the head dimension.
  const d = q.dims[3] / components;
  const numTiles = Math.ceil(q.dims[3] / bC);
  const workgroupSize: [number, number, number] = [8, 32, 1];
  const qInner = numTiles * workgroupSize[0];
  const colsPerThread = 4; // bC / workgroupSize[0]
  // Softmax scale: defaults to 1/sqrt(headSize) when the scale attribute is 0.
  const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
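  // For a concrete (hypothetical) shape, q/k/v dims [1, 12, 1024, 64]:
  // bR = bC = 32, tR = tC = 32, d = 16, numTiles = 2, qInner = 16,
  // and the dispatch below is { x: 32, y: 12, z: 1 }.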

  // One workgroup per (query row block, head, batch) triple.
  const dispatch = {
    x: tR,
    y: v.dims[1],
    z: v.dims[0],
  };

  const headOffset = v.dims[2] * v.dims[3];
  const batchOffset = v.dims[1] * headOffset;
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: batchOffset },
    { type: DataType.uint32, data: headOffset },
    { type: DataType.uint32, data: v.dims[1] },
    { type: DataType.uint32, data: v.dims[3] },
    { type: DataType.uint32, data: d },
    { type: DataType.float, data: alpha },
  ];

  // Output layout is [batch, sequenceLength, numHeads * headSize].
  const outputDims = [v.dims[0], v.dims[2], v.dims[1] * v.dims[3]];
  const outputs = [{ dims: outputDims, dataType: q.dataType, gpuDataType: GpuDataType.default }];
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const qInput = inputVariable('Q', q.dataType, q.dims, components);
    const kInput = inputVariable('K', k.dataType, k.dims, components);
    const vInput = inputVariable('V', k.dataType, k.dims, components);
    const inputVars = [qInput, kInput, vInput];

    const output = outputVariable('output', q.dataType, outputDims);
    const outputVars = [output];
    const type = tensorTypeToWsglValueType(v.dataType);

    const uniforms: UniformsArrayType = [
      { name: 'batchOffset', type: 'u32' },
      { name: 'headOffset', type: 'u32' },
      { name: 'headNum', type: 'u32' },
      { name: 'headSize', type: 'u32' },
      { name: 'd', type: 'u32' },
      { name: 'alpha', type: 'f32' as UniformDataElementType },
    ];

    return `
  var<workgroup> Q_i : array<array<${qInput.type.storage}, ${qInner}>, ${bR}>;
  var<workgroup> KV_j : array<array<${qInput.type.storage}, ${workgroupSize[0]}>, ${bC}>;
  var<workgroup> S_i_j : array<array<${output.type.storage}, ${bC}>, ${bR}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
  ${shaderHelper.mainStart(workgroupSize)}
    let offset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + workgroup_id.x * ${bR} * uniforms.d;
    var O_i : array<${type}, ${numTiles * colsPerThread}>;
    var l_i_j = ${type}(0);
    var m_i_j = ${type === 'f32' ? 'f32(-3.402823e+38f)' : 'f16(-65504)'};
    // Load the Br x d block of Q rows handled by this workgroup into shared memory.
    for (var tile = 0; tile < ${numTiles}; tile++) {
      Q_i[local_id.y][u32(${workgroupSize[0]} * tile) + local_id.x] = Q[offset + local_id.y * uniforms.d + u32(${workgroupSize[0]} * tile) + local_id.x];
    }
    for (var j = 0; j < ${tC}; j++) {
      // S_i_j = alpha * Q_i * K_j^T, accumulated tile-by-tile along the head dimension.
      var acc : array<${type}, ${colsPerThread}>;
      let kvOffset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + u32(j * ${bC}) * uniforms.d;
      for (var tile = 0; tile < ${numTiles}; tile++) {
        KV_j[local_id.y][local_id.x] = K[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
        workgroupBarrier();
        for (var col = 0; col < ${colsPerThread}; col++) {
          for (var k = 0; k < ${workgroupSize[0]}; k++) {
            acc[col] += dot(Q_i[local_id.y][k + tile * 8], KV_j[local_id.x + u32(col * 8)][k]);
          }
        }
        workgroupBarrier();
      }
      for (var col = 0; col < ${colsPerThread}; col++) {
        S_i_j[local_id.y][u32(col * 8) + local_id.x] = acc[col] * ${type}(uniforms.alpha);
      }
      workgroupBarrier();
      // Online softmax: update the running row max, then rescale the running sum and output.
      let m_i_j_1 = m_i_j;
      for (var m = 0; m < ${bC}; m++) {
        m_i_j = max(m_i_j, S_i_j[local_id.y][m]);
      }
      let exp_j_j_1 = exp(m_i_j_1 - m_i_j);
      l_i_j *= exp_j_j_1;
      for (var o = 0; o < ${colsPerThread * numTiles}; o++) {
        O_i[o] *= exp_j_j_1;
      }
      // O_i += exp(S_i_j - m_i_j) * V_j.
      for (var tile = 0; tile < ${numTiles}; tile++) {
        KV_j[local_id.y][local_id.x] = V[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
        workgroupBarrier();
        for (var d = 0; d < ${bC}; d++) {
          let p_i_j = exp(S_i_j[local_id.y][d] - m_i_j);
          if (tile == 0) {
            l_i_j += p_i_j;
          }
          for (var col = 0; col < ${colsPerThread}; col++) {
            let v_i_j = KV_j[d][(u32(8 * col) + local_id.x) / 4][(u32(8 * col) + local_id.x) % 4];
            O_i[col * ${numTiles} + tile] += p_i_j * v_i_j;
          }
        }
        workgroupBarrier();
      }
    }
    // Normalize by the softmax denominator and write the result out.
    let outputOffset = workgroup_id.z * uniforms.batchOffset + (workgroup_id.x * 32 + local_id.y) * uniforms.headNum
      * uniforms.headSize + workgroup_id.y * uniforms.headSize + local_id.x;
    for (var tile = 0; tile < ${numTiles}; tile++) {
      for (var col = 0; col < ${colsPerThread}; col++) {
        let outputIndx = outputOffset + u32(tile * ${bC}) + u32(col * 8);
        output[outputIndx] = O_i[col * ${numTiles} + tile] / l_i_j;
      }
    }
  }`;
  };
  return {
    name: 'FlashAttentionV2',
    shaderCache: { hint: `${numTiles}`, inputDependencies },
    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
    getShaderSource,
  };
};

const applyFlashAttentionV2 = (
  context: ComputeContext,
  q: TensorView,
  k: TensorView,
  v: TensorView,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
) => {
  const inputs = [q, k, v];
  context.compute(createFlashAttentionV2ProgramInfo(q, k, v, parameters, attributes), { inputs });
};

export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
  const params = validateAttentionInputs(context.inputs, attributes);

  const [q, k, v] = prepare(context, params);

  // Take the FlashAttention-2 path only for long (>= 1024) sequences and head sizes that are
  // multiples of 32, with headSize <= 128 and without the optional inputs 4 and 5.
  if (
    params.sequenceLength >= 1024 &&
    params.sequenceLength % 32 === 0 &&
    params.headSize <= 128 &&
    params.headSize % 32 === 0 &&
    context.inputs[4] === undefined &&
    context.inputs[5] === undefined
  ) {
    return applyFlashAttentionV2(context, q, k, v, params, attributes);
  } else {
    return applyAttention(
      context,
      q,
      k,
      v,
      context.inputs[4],
      undefined,
      undefined,
      undefined,
      context.inputs[5],
      params,
    );
  }
};
