diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 6a78c8ae3b190..dfbd9f4933523 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -1004,21 +1004,188 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
   );
 };
 
+// The algorithm follows the FlashAttention-2 forward pass described in https://tridao.me/publications/flash2/flash2.pdf
+const createFlashAttentionV2ProgramInfo = (
+  q: TensorView,
+  k: TensorView,
+  v: TensorView,
+  parameters: AttentionParameters,
+  attributes: AttentionAttrs,
+) => {
+  const components = 4;
+  // Tile sizes in the paper's notation: Br/Bc are the row/column block sizes, Tr/Tc the number of
+  // query/key blocks along the sequence dimension.
+  const Br = 32;
+  const Bc = Br;
+  const Tr = q.dims[2] / Br;
+  const Tc = k.dims[2] / Bc;
+  const d = q.dims[3] / components;
+  const numTiles = Math.ceil(q.dims[3] / Bc);
+  const workgroupSize: [number, number, number] = [8, 32, 1];
+  const qInner = numTiles * workgroupSize[0];
+  const colsPerThread = 4; // (Bc / workgroupSize[0])
+  const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
+
+  const dispatch = {
+    x: Tr,
+    y: v.dims[1],
+    z: v.dims[0],
+  };
+
+  const headOffset = v.dims[2] * v.dims[3];
+  const batchOffset = v.dims[1] * headOffset;
+  const programUniforms: ProgramUniform[] = [
+    { type: DataType.uint32, data: batchOffset },
+    { type: DataType.uint32, data: headOffset },
+    { type: DataType.uint32, data: v.dims[1] },
+    { type: DataType.uint32, data: v.dims[3] },
+    { type: DataType.uint32, data: d },
+    { type: DataType.float, data: alpha },
+  ];
+
+  const outputDims = [v.dims[0], v.dims[2], v.dims[1] * v.dims[3]];
+  const outputs = [{ dims: outputDims, dataType: q.dataType, gpuDataType: GpuDataType.default }];
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
+
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    const qInput = inputVariable('Q', q.dataType, q.dims, components);
+    const kInput = inputVariable('K', k.dataType, k.dims, components);
+    const vInput = inputVariable('V', k.dataType, k.dims, components);
+    const inputVars = [qInput, kInput, vInput];
+
+    const output = outputVariable('output', q.dataType, outputDims);
+    const outputVars = [output];
+    const type = tensorTypeToWsglValueType(v.dataType);
+
+    const uniforms: UniformsArrayType = [
+      { name: 'batchOffset', type: 'u32' },
+      { name: 'headOffset', type: 'u32' },
+      { name: 'headNum', type: 'u32' },
+      { name: 'headSize', type: 'u32' },
+      { name: 'd', type: 'u32' },
+      { name: 'alpha', type: 'f32' as UniformDataElementType },
+    ];
+
+    return `
+  var<workgroup> Q_i : array<array<vec4<${type}>, ${qInner}>, ${Br}>;
+  var<workgroup> KV_j : array<array<vec4<${type}>, ${workgroupSize[0]}>, ${Bc}>;
+  var<workgroup> S_i_j : array<array<${type}, ${Bc}>, ${Br}>;
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
+  ${shaderHelper.mainStart(workgroupSize)}
+    let offset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + workgroup_id.x * ${Br} * uniforms.d;
+
+    var O_i : array<${type}, ${numTiles * colsPerThread}>;
+    var l_i_j = ${type}(0);
+    var m_i_j = ${type === 'f32' ? 'f32(-3.402823e+38f)' : 'f16(-65504)'};
+
+    for (var tile = 0; tile < ${numTiles}; tile++) {
+      Q_i[local_id.y][u32(${workgroupSize[0]} * tile) + local_id.x] = Q[offset + local_id.y * uniforms.d + u32(${workgroupSize[0]} * tile) + local_id.x];
+    }
+
+    for (var j = 0; j < ${Tc}; j++) {
+      var acc : array<${type}, ${colsPerThread}>;
+
+      let kvOffset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + u32(j * ${Bc}) * uniforms.d;
+      for (var tile = 0; tile < ${numTiles}; tile++) {
+        KV_j[local_id.y][local_id.x] = K[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
+        workgroupBarrier();
+        for (var col = 0; col < ${colsPerThread}; col++) {
+          for (var k = 0; k < ${workgroupSize[0]}; k++) {
+            acc[col] += dot(Q_i[local_id.y][k + tile * 8], KV_j[local_id.x + u32(col * 8)][k]);
+          }
+        }
+        workgroupBarrier();
+      }
+
+      for (var col = 0; col < ${colsPerThread}; col++) {
+        S_i_j[local_id.y][u32(col * 8) + local_id.x] = acc[col] * ${type}(uniforms.alpha);
+      }
+
+      workgroupBarrier();
+      let m_i_j_1 = m_i_j;
+      for (var m = 0; m < ${Bc}; m++) {
+        m_i_j = max(m_i_j, S_i_j[local_id.y][m]);
+      }
+
+      // Online softmax: rescale the running denominator and accumulator by exp(m_prev - m_new).
+      let exp_j_j_1 = exp(m_i_j_1 - m_i_j);
+      l_i_j *= exp_j_j_1;
+      for (var o = 0; o < ${colsPerThread * numTiles}; o++) {
+        O_i[o] *= exp_j_j_1;
+      }
+
+      for (var tile = 0; tile < ${numTiles}; tile++) {
+        KV_j[local_id.y][local_id.x] = V[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
+        workgroupBarrier();
+
+        for (var d = 0; d < ${Bc}; d++) {
+          let p_i_j = exp(S_i_j[local_id.y][d] - m_i_j);
+          if (tile == 0) {
+            l_i_j += p_i_j;
+          }
+
+          for (var col = 0; col < ${colsPerThread}; col++) {
+            let v_i_j = KV_j[d][(u32(8 * col) + local_id.x) / 4][(u32(8 * col) + local_id.x) % 4];
+            O_i[col * ${numTiles} + tile] += p_i_j * v_i_j;
+          }
+        }
+        workgroupBarrier();
+      }
+    }
+
+    let outputOffset = workgroup_id.z * uniforms.batchOffset + (workgroup_id.x * 32 + local_id.y) * uniforms.headNum
+      * uniforms.headSize + workgroup_id.y * uniforms.headSize + local_id.x;
+    for (var tile = 0; tile < ${numTiles}; tile++) {
+      for (var col = 0; col < ${colsPerThread}; col++) {
+        let outputIndx = outputOffset + u32(tile * ${Bc}) + u32(col * 8);
+        output[outputIndx] = O_i[col * ${numTiles} + tile] / l_i_j;
+      }
+    }
+  }`;
+  };
+  return {
+    name: 'FlashAttentionV2',
+    shaderCache: { hint: `${numTiles}`, inputDependencies },
+    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
+    getShaderSource,
+  };
+};
+
+const applyFlashAttentionV2 = (
+  context: ComputeContext,
+  q: TensorView,
+  k: TensorView,
+  v: TensorView,
+  parameters: AttentionParameters,
+  attributes: AttentionAttrs,
+) => {
+  const inputs = [q, k, v];
+  context.compute(createFlashAttentionV2ProgramInfo(q, k, v, parameters, attributes), { inputs });
+};
+
 export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
   const params = validateAttentionInputs(context.inputs, attributes);
 
   const [q, k, v] = prepare(context, params);
 
-  return applyAttention(
-    context,
-    q,
-    k,
-    v,
-    context.inputs[4],
-    undefined,
-    undefined,
-    undefined,
-    context.inputs[5],
-    params,
-  );
+  if (
+    params.sequenceLength >= 1024 &&
+    params.sequenceLength % 32 === 0 &&
+    params.headSize <= 128 &&
+    params.headSize % 32 === 0 &&
+    context.inputs[4] === undefined &&
+    context.inputs[5] === undefined
+  ) {
+    return applyFlashAttentionV2(context, q, k, v, params, attributes);
+  } else {
+    return applyAttention(
+      context,
+      q,
+      k,
+      v,
+      context.inputs[4],
+      undefined,
+      undefined,
+      undefined,
+      context.inputs[5],
+      params,
+    );
+  }
 };
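
Note (reviewer sketch, not part of the diff): the WGSL kernel above implements the FlashAttention-2 forward pass. Per query row it keeps a running score maximum (m_i_j), a running softmax denominator (l_i_j) and an unnormalized output accumulator (O_i), and rescales the latter two by exp(m_prev - m_new) whenever a new K/V block raises the maximum. Below is a minimal scalar TypeScript sketch of that recurrence for a single head; the helper name flashAttentionReference and the plain number[][] layout are illustrative assumptions, not APIs from this PR.

// Scalar reference for one attention head, following the same online-softmax recurrence
// as the shader above. Illustrative only; not part of the PR.
const flashAttentionReference = (
  Q: number[][], // [seqLen][headSize]
  K: number[][], // [kvLen][headSize]
  V: number[][], // [kvLen][headSize]
  blockSize: number, // plays the role of Bc
  scale: number, // plays the role of alpha
): number[][] => {
  const out: number[][] = [];
  for (let i = 0; i < Q.length; i++) {
    let m = -Infinity; // running max of scores (m_i_j)
    let l = 0; // running sum of exp(score - m) (l_i_j)
    const acc = new Array<number>(Q[0].length).fill(0); // unnormalized output row (O_i)
    for (let j0 = 0; j0 < K.length; j0 += blockSize) {
      const jEnd = Math.min(j0 + blockSize, K.length);
      // Scores of this query row against the current K block.
      const s: number[] = [];
      for (let j = j0; j < jEnd; j++) {
        let dot = 0;
        for (let d = 0; d < Q[0].length; d++) {
          dot += Q[i][d] * K[j][d];
        }
        s.push(dot * scale);
      }
      // Online softmax: fold the previous blocks' contributions under the new maximum.
      const mPrev = m;
      m = Math.max(m, ...s);
      const corr = Math.exp(mPrev - m);
      l *= corr;
      for (let d = 0; d < acc.length; d++) {
        acc[d] *= corr;
      }
      // Accumulate this block's contribution.
      for (let j = j0; j < jEnd; j++) {
        const p = Math.exp(s[j - j0] - m);
        l += p;
        for (let d = 0; d < acc.length; d++) {
          acc[d] += p * V[j][d];
        }
      }
    }
    out.push(acc.map((x) => x / l)); // normalize, as the shader does at write-back
  }
  return out;
};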