
jatinwadhwa921

Back merging with Msft commits

qjia7 and others added 23 commits April 8, 2025 08:27
This PR adds flash decoding support to optimize generation speed when the
total sequence length is large. Previously, when the total sequence length
was big enough, the softmax and softmax * v shaders would become the
bottleneck since they only use a limited number of GPU cores. This change
adds flash decoding support to split the present key/value based on the
total sequence length, then performs a reduce step to get the final result.

On NV RTX 2000 Ada, TPS improves from 34.4 to 41.4 for 1K tokens for phi4
with static kv cache.
On Meteor Lake, TPS improves from 16 to 19 for 1K tokens for phi4 with
static kv cache.

Side effect of this PR:
It adds two extra buffers, which increases memory usage: 1) metadata (max
and exp_sum for each split), 2) the split qkv results with shape
[B, N, split_k, H].
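
For illustration, here is a minimal CPU-side sketch of that reduce step for one (batch, head) pair, merging the per-split metadata and partial outputs into the final result. The real implementation is a WGSL shader; all names here are illustrative, not the actual kernel code.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// split_max[i] / split_exp_sum[i] come from the metadata buffer; partial[i]
// is split i's un-normalized output row from the [B, N, split_k, H] buffer,
// i.e. sum_j exp(score_j - split_max[i]) * v_j over that split's kv range.
std::vector<float> ReduceSplits(const std::vector<std::vector<float>>& partial,
                                const std::vector<float>& split_max,
                                const std::vector<float>& split_exp_sum,
                                size_t head_size) {
  const size_t split_k = partial.size();
  float global_max = -INFINITY;
  for (float m : split_max) global_max = std::max(global_max, m);

  std::vector<float> out(head_size, 0.0f);
  float denom = 0.0f;
  for (size_t i = 0; i < split_k; ++i) {
    // Rescale each split from its local max to the global max before summing.
    const float scale = std::exp(split_max[i] - global_max);
    denom += scale * split_exp_sum[i];
    for (size_t h = 0; h < head_size; ++h) {
      out[h] += scale * partial[i][h];
    }
  }
  for (size_t h = 0; h < head_size; ++h) out[h] /= denom;
  return out;
}
```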

TODO:
Ideally, there should only be two shaders, which would also reduce the
intermediate memory. The computeQKT shader could be merged into the split
shader, with the final softmax adjustment done in the reduce shader.
However, I hit an issue where the result becomes garbage once the total
sequence length exceeds a certain value. Since I can't resolve it quickly,
it is left as a TODO to fix in the future.
### Description
Use wasm_f32x4_relaxed_max and wasm_f32x4_relaxed_min in the WASM relaxed
SIMD build.


### Motivation and Context
This PR replaces wasm_f32x4_min/max with their relaxed SIMD counterparts
wasm_f32x4_relaxed_min/max in the WASM relaxed SIMD build.

According to the [relaxed SIMD
proposal](https://github.com/WebAssembly/relaxed-simd/blob/main/proposals/relaxed-simd/Overview.md#relaxed-min-and-max),
wasm_f32x4_relaxed_min/max allow implementation-defined behavior for NaN
propagation and -0.0 vs +0.0. This lets WASM runtimes use minps/maxps on
x64 platforms, improving performance.

For example, for wasm_f32x4_max -> wasm_f32x4_relaxed_max:
wasm_f32x4_max: [implementation in
V8](https://source.chromium.org/chromium/chromium/src/+/main:v8/src/codegen/shared-ia32-x64/macro-assembler-shared-ia32-x64.cc;l=231)
wasm_f32x4_relaxed_max: maxps

This change affects kernel functions that rely on MlasMaximumFloat32x4
and MlasMinimumFloat32x4, including various activations and reduced
min/max kernels. In the MLAS micro-benchmark "COMPUTESOFTMAXINPLACE...",
this change provides a performance improvement of up to 60% on x64 devices.
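
As a hedged sketch of the substitution (assuming the compiler defines `__wasm_relaxed_simd__` when building with `-mrelaxed-simd`; the actual guard used in MLAS may differ):

```cpp
#include <wasm_simd128.h>

// Elementwise float32x4 max in the style of MlasMaximumFloat32x4.
v128_t MaximumFloat32x4(v128_t a, v128_t b) {
#if defined(__wasm_relaxed_simd__)
  // Relaxed semantics: NaN propagation and -0.0 vs +0.0 are
  // implementation-defined, so V8 can lower this to a single maxps on x64.
  return wasm_f32x4_relaxed_max(a, b);
#else
  // Strict semantics need extra x64 instructions to canonicalize NaN and
  // signed-zero behavior.
  return wasm_f32x4_max(a, b);
#endif
}
```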
WebGPU support for DequantizeLinear
### Description
`nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH` [has been
deprecated since
10.0](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/namespacenvinfer1.html#aa8f406be96c14b7dbea548cf19f09a08)
and is always implicitly set for versions 10.0+. Change the EP code to
only set this flag for TRT versions 8 and below.
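
A hedged sketch of the resulting guard (`NV_TENSORRT_MAJOR` comes from the TensorRT headers; the actual EP call site differs):

```cpp
#include <cstdint>
#include "NvInfer.h"

nvinfer1::INetworkDefinition* CreateNetwork(nvinfer1::IBuilder& builder) {
  uint32_t flags = 0;
#if NV_TENSORRT_MAJOR <= 8
  // Explicit batch must be requested on TRT 8 and below; from 10.0 on it is
  // always implied and the flag is deprecated.
  flags = 1U << static_cast<uint32_t>(
              nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
#endif
  return builder.createNetworkV2(flags);
}
```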

### Motivation and Context

Removes deprecated API usages in the TRT EP code.

Signed-off-by: Kevin Chen <[email protected]>
"channels" should be validated before divided by "components".
"components" should be passed to program inputs and outputs. Rename
"input" to "x" to match "ErfImpl".
Correct the last dimension of output shape.
"channels" should be validated before divided by "components". 
"components" should be passed to program inputs and outputs.
If batch_size and sequence_length are both one, split hidden_size to
improve parallelism.

hipClang does not support -Wno-interference-size, so remove the option to
avoid a build error.
### Description

This PR revises the flag `ort.env.wasm.simd` to cover more usage scenarios:
- Allow setting it to `false` explicitly to disable the SIMD check. Resolves
microsoft#24292 (@Eldow)
- Allow setting it to `'relaxed'` to enable the Relaxed SIMD check. Relaxed
SIMD support was first introduced in microsoft#22794 (@jing-bao)
- Behavior is unchanged when the flag is unset (i.e. `undefined`) or set
to `true`
- A warning message is emitted when the flag is set to an unknown value,
and it is reset to `false` in that case
…t#24350)

### Description
Exclude the zero-dim input test case for the WebGPU EP.



1. Split build.py into two files, because the file is currently over 3000
lines. This PR moves 900 of them to a new file.
2. Put the build args into groups. This makes it more explicit that the
"--x86", "--arm", "--arm64" and "--arm64ec" args are Windows-only.
3. Remove the "--use_avx512" and "--gen-api-doc" build args, as they are
not referenced anywhere. "--gen-api-doc" was for generating documentation
for the PyTorch frontend.
4. Remove MPI-related build flags.
5. Delete tools/ci_build/github/pai/orttraining-ci.yml.
6. Remove --use_preinstalled_eigen and --eigen_path. We now have a more
unified approach for all of ORT's dependencies (not just Eigen). See
https://onnxruntime.ai/docs/build/dependencies.html for more
information.
7. Windows-specific build options no longer show up on non-Windows
platforms, and likewise for macOS.
### Description

This PR is one of a series of changes for optimization of Dawn API
usage. See microsoft#24281

Optimize the code for workgroup dispatch in the `WebGpuContext` class.

The updated code prefers the C API over the C++ API for WebGPU. This is
because the C++ API uses the class `wgpu::Buffer`, which causes a
significant number of calls to `wgpuBufferAddRef` and `wgpuBufferRelease`
to manage the buffer's lifetime correctly. For this specific use case in
ONNX Runtime (launching a compute shader program), the C API is more
efficient.
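
A hedged illustration of the ref-count difference (not the actual ORT code; the headers are Dawn's standard `webgpu.h` / `webgpu_cpp.h`):

```cpp
#include <webgpu/webgpu.h>      // C API: raw handles, no implicit ownership
#include <webgpu/webgpu_cpp.h>  // C++ API: RAII wrappers over the handles

void UseCppApi(const wgpu::Buffer& src) {
  wgpu::Buffer copy = src;  // calls wgpuBufferAddRef under the hood
}                           // destructor calls wgpuBufferRelease

void UseCApi(WGPUBuffer buffer) {
  // A raw WGPUBuffer is a plain handle: borrowing it for a single dispatch
  // adds no ref-count traffic, but the caller must keep the buffer alive.
}
```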
### Description

1. Added a 'ProcessInt64Tensors' method in BaseOpBuilder to handle common input processing for the graph.
2. Added logic in ProcessOutputs to handle the common Cast addition at the output.
3. Adds a Cast op at the input to convert graph inputs to int32.
4. Initializers and activation inputs are handled by casting int64_t data to int32_t for QNN compatibility, resizing and copying the data (see the sketch after this list).
5. Modified `TransposeOpBuilder` and `GatherOpBuilder` to handle processing outputs.
6. Added a unit test for a Reshape op running with int64 inputs.
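
A hedged sketch of the narrowing in item 4 (names are illustrative, not the actual BaseOpBuilder API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// QNN has no int64 tensor type, so int64 initializer data is copied into an
// int32 buffer. Values are assumed to fit in int32; production code should
// validate the range before narrowing.
std::vector<int32_t> NarrowInt64ToInt32(const int64_t* src, size_t count) {
  std::vector<int32_t> dst(count);
  for (size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<int32_t>(src[i]);
  }
  return dst;
}
```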
Support WebGPU build for Android and iOS
- Add Java API for Android test
- Patch dawn to reduce warnings for UNSAFE_BUFFER_USAGE
### Description

Resolves microsoft#24343. Also
added a test case to avoid breaking TypeScript module resolution in the
future.

The MatMulNBits op can simply be emulated by DequantizeLinear + Transpose +
MatMul; currently only 4-bit quantization is supported.

Thus the B and zero_points (if present) inputs must be known as
initializers with data type 'uint8', and we need to register them as
'uint4' WebNN constants.

Typically, all initializers are registered as WebNN constants in one
step via `ModelBuilder::RegisterInitializers` before constructing the
WebNN graph. However, because WebNN doesn't support casting to 'uint4', we
need to defer the registration of these two inputs until
`MatMulNBitsBuilder::AddToModelBuilderImpl` is invoked.
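
As a hedged sketch of the per-element dequantization being emulated (assuming the packed layout of two 4-bit values per uint8 byte, low nibble first; the actual WebNN decomposition runs DequantizeLinear on whole tensors):

```cpp
#include <cstddef>
#include <cstdint>

// Dequantize element `idx` of a packed 4-bit row:
// dequant = (q - zero_point) * scale, with q an unsigned 4-bit value.
float DequantizeUint4(const uint8_t* packed, size_t idx, float scale,
                      uint8_t zero_point) {
  const uint8_t byte = packed[idx / 2];
  const uint8_t q = (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  return (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
}
```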
…oft#24236)

### Description
<!-- Describe your changes. -->
This script can upload a local perf log/CSV to the DB, which can be used
as an external data source for the EP Perf Dashboard.
(Make sure the CSV/log-parsing logic matches the target DB table's
schema.)

#### Usage:
* To post a CSV to the DB:
`python parse_post_perf.py --kusto-table="<table_name>"
--kusto-conn="<db_link>" --kusto-db="<dashboard_xyz>"
--upload-csv="<path\to\data.csv>"
`
* To parse a mobile perf log and post it to the DB:
`python parse_post_perf.py --kusto-table="<table_name>"
--kusto-conn="<db_link>" --kusto-db="<dashboard_xyz>"
--parse_mobile_perf --log-file="<path/to/mobile_model_benchmark.log>"
--model="<model_name>" --device-id="<device_name>"
--commit-id="<ort_commit_id>" --ep="<test_backend>"`

…ft#24352)

### Description
WebGPU, VitisAI, and DML are missing from the list.

### Motivation and Context
If users misspell a provider name, this error should show them the full
list of possibilities. Leaving one out will lead to confusion.

I noticed while testing new providers in GenAI that the error message was
not up to date.
### Description
<!-- Describe your changes. -->
Adds support for GroupQueryAttention via WebNN matmul, transpose,
reshape, and other operations that follow the logic in the GQA subgraph
below.

```
 Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length
                N is number of attention heads, H is head size, and W=N*H, h=Sqrt(H), G is group size.
    GQA inputs: query, key value, past_key, past_value, seqlens_k, total_sequence_length
    Notes: If the datatype of the inputs (qkv and past kv) is float16, we cast them to float32 to ensure data precision.

          query      key               value
            |         |                  |
         Reshape   Reshape            Reshape (B,S,H,N)     seqlens_k
            |         |                  |                  /       |
            |         |       past_value |   (scatter_indices*)     |
        q_Transpose   |              \   |   /                      |
        (0,2,1,3)     | past_key    ScatterND-----------------------|------> present_value
             \        |  /              |                           |
present_key<--\----ScatterND         Expand(G)      (attention_bias, one/finfo_min mask*)
               \      |                 |              /
               |   Expand(G)            |             /
               |      |                 |            /
               |  k_Transpose           |           /
               |   (0,1,3,2)            |          /
               |      |                 |         /
            +---------------------------------------+
            |        ScaledDotProductAttention      |
            +---------------------------------------+
                             |
                           output

```
The ScaledDotProductAttention logic is:
```
    ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention
    inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)
    Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
                  N is number of attention heads, H is head size, W is hidden_size

  query         key
    |            |
    +---matmul---+    scale
          |             |
          +-----div-----+   attn_mask
                 |             |
                 +-----add-----+
                        |
                     softmax            value
                        |                 |
                        +------matmul-----+
                                 |
                   (0,2,1,3) transpose B,H,S,N -> B,S,H,N
                                 |
                              Reshape B,S,H,N -> B,S,W
                                 |
                               output
```
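
In equation form, the subgraph computes standard scaled dot-product attention, with $\sqrt{H}$ as the scale:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{H}}+\mathrm{attn\_mask}\right)V$$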
Calculation of scatter_indices:
```
                                                                                               if_prefill (0/1 constant)
                                                                                                    |
        scatter_indices_left_constant             scatter_indices_right_constant           0 ---> Where <--- Cast <---seqlens_k
                      |                                          |                                  |
                      |                                         Add <--------------------------- scatter_pos*
                      |                                          |
                      +--------------------+---------------------+
                                           |
                                      scatter_indices
```

Calculation of attention_bias:
```
                  ones_array (shape=B,N,S,P)                                  range_of_qkv_sequence_length_constant (0,1,2,...) (shape=S)
                      |                                                                          |
                   CumSum (axis=3, exclusive=true, reversed=false)                              Add <--- scatter_pos
                      |                                                                          |
                      |                                                                        Expand (shape=P,S)
                      |                                                                          |
                      +-------------------------------> Lesser <------------------------------Transpose (1,0)
                                                           |
                                                  1 ---> Where <--- finfo_min (minimum value of FP32)
                                                           |
                                                      attention_bias
```

*Note: currently we only support `past_sequence_length ==
total_sequence_length` for GQA.*

@jatinwadhwa921 jatinwadhwa921 requested a review from ankitm3k April 10, 2025 06:44
@jatinwadhwa921 jatinwadhwa921 merged commit a5e6e05 into ovep-develop Apr 10, 2025
6 of 8 checks passed