Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into webnn_lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 authored Aug 29, 2024
2 parents b76d3a9 + be76e1e commit 7dce2af
Show file tree
Hide file tree
Showing 28 changed files with 469 additions and 347 deletions.
59 changes: 55 additions & 4 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1,64 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"

#include <cuda_runtime.h>
+#include <cuda_fp16.h>

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
Expand All @@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_con
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
// Users are not expected to set the value to 0 as it is reserved for internal use.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";

// Specify the type of workload for this run.
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
// "Efficient": OS treats this workload as efficiency-oriented, with low scheduling priority and efficient processor performance.
static const char* const kOrtRunOptionsWorkloadType = "run.workload_type";
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,8 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";

// Specify the type of workload for this session.
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
// "Efficient": OS treats this workload as efficiency-oriented, with low scheduling priority and efficient processor performance.
static const char* const kOrtSessionOptionsWorkloadType = "session.workload_type";
16 changes: 8 additions & 8 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ class TensorViewImpl implements TensorView {
public readonly dims: readonly number[],
) {}

getUint16Array(): Uint16Array {
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
}

getFloat32Array(): Float32Array {
if (this.dataType !== DataType.float) {
throw new Error('Invalid data type');
Expand Down Expand Up @@ -59,6 +51,14 @@ class TensorViewImpl implements TensorView {
return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
}

getUint16Array(): Uint16Array {
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
}

reshape(newDims: readonly number[]): TensorView {
if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
throw new Error('Invalid new shape');
Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/wasm/jsep/tensor-view.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ export interface TensorView {
*/
getInt32Array(): Int32Array;

/**
* get a Uint16Array data view of the tensor data. tensor data must be on CPU.
*/
getUint16Array(): Uint16Array;

/**
* create a new tensor view with the same data but different dimensions.
*/
Expand Down
112 changes: 79 additions & 33 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common';
import {
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglValueType,
UniformDataElementType,
UniformsArrayType,
} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
Expand All @@ -20,6 +27,7 @@ const createElementwiseProgramShader = (
outputDataType: number,
funcCall: ElementwiseFunctionCall,
additionalImplementation?: string,
additionalUniformsType?: UniformsArrayType,
): string => {
const vecSize = Math.ceil(datasize / 4);

Expand All @@ -32,9 +40,13 @@ const createElementwiseProgramShader = (

const input = inputVariable('inputData', inputDataType, [vecSize], 4);
const output = outputVariable('outputData', outputDataType, [vecSize], 4);
const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }];
if (additionalUniformsType) {
uniforms.push(...additionalUniformsType);
}

return `
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${additionalImplementation ?? ''}
Expand All @@ -53,24 +65,38 @@ const createElementwiseProgramInfo = (
additionalImplementation?: string,
cacheKey?: string,
outputDataType: number = input.dataType,
): ProgramInfo => ({
name,
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
getShaderSource: (shaderHelper) =>
createElementwiseProgramShader(
shaderHelper,
ShapeUtil.size(input.dims),
input.dataType,
outputDataType,
funcCall,
additionalImplementation,
),
getRunData: (inputTensors) => ({
outputs: [{ dims: input.dims, dataType: outputDataType }],
dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) },
programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }],
}),
});
additionalUniforms?: ProgramUniform[],
additionalUniformsType?: UniformsArrayType,
): ProgramInfo => {
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) },
];
if (additionalUniforms) {
programUniforms.push(...additionalUniforms);
}

return {
name,
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
getShaderSource: (shaderHelper) =>
createElementwiseProgramShader(
shaderHelper,
ShapeUtil.size(input.dims),
input.dataType,
outputDataType,
funcCall,
additionalImplementation,
additionalUniformsType,
),
getRunData: (inputTensors) => ({
outputs: [{ dims: input.dims, dataType: outputDataType }],
dispatchGroup: {
x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */),
},
programUniforms,
}),
};
};

export const abs = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs'));
Expand Down Expand Up @@ -139,24 +165,46 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = inputs.length >= 3 && inputs[2].data !== 0 ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
let min: number;
let max: number;
const hasMin = inputs.length >= 2 && inputs[1].data !== 0;
const hasMax = inputs.length >= 3 && inputs[2].data !== 0;

switch (inputs[0].dataType) {
case DataType.float:
min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38;
max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38;
break;
case DataType.float16:
min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0)
max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0)
break;
default:
throw new Error('Unsupport data type');
}

return createAttributeWithCacheKey({ min, max });
};

export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
context.inputs[0],
'Clip',
(a) => `clamp(${a}, clip_min_, clip_max_)`,
`
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
(a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`,
undefined,
attributes.cacheKey,
undefined,
[
{ type: context.inputs[0].dataType, data: attributes.min },
{ type: context.inputs[0].dataType, data: attributes.max },
],
[
{ name: 'min', type: dataType as UniformDataElementType },
{ name: 'max', type: dataType as UniformDataElementType },
],
),
{ inputs: [0] },
);
Expand Down Expand Up @@ -302,9 +350,7 @@ export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttr
context.inputs[0],
'HardSigmoid',
(a) =>
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
attributes.beta
})))`,
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`,
undefined,
attributes.cacheKey,
),
Expand Down
Loading

0 comments on commit 7dce2af

Please sign in to comment.