@titaiwangms commented Jan 15, 2026

## Reland reason

Reland #26466

The previous PR was reverted because it failed the following tests:

  1. Windows GPU CUDA CI Pipeline Test Job
  2. Windows GPU TensorRT CI Pipeline Test Job

This PR includes the correct fix.


## Description

This pull request introduces significant improvements and expanded
support for multi-head attention kernels in ONNX Runtime, particularly
focusing on supporting both 3D (`BSNH`) and 4D (`BNSH`) QKV input
formats. The changes enhance flexibility, correctness, and
maintainability for attention operations across CPU and CUDA
implementations.

### Expanded QKV Input Format Support

* Added support for the 4D QKV input format (`Q_K_V_BNSH`) in CUDA attention
kernels, including handling for the cases with and without past/present
state, and enforcing that bias is not supported for this format. This
includes logic to avoid unnecessary transposes and to write outputs directly
when possible (a minimal dispatch sketch follows below).
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R264-R265)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R343-R354)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R388-L388)
[[4]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R426-R435)
[[5]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)
[[6]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R747-R748)
[[7]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
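
To make the dispatch concrete, here is a minimal sketch of the idea
(illustrative names only, not the actual ONNX Runtime code): 4D `Q_K_V_BNSH`
inputs reject bias and skip the usual transpose, while 3D `Q_K_V_BSNH` inputs
still go through it.

```cpp
// Hypothetical sketch; QkvFormat, AttentionInputs, and ValidateAndPlan are
// illustrative names, not the actual ONNX Runtime types.
enum class QkvFormat { Q_K_V_BSNH, Q_K_V_BNSH };

struct AttentionInputs {
  QkvFormat format;
  bool has_bias;
  bool has_past;  // past key/value state supplied
};

// Returns an error message when the configuration is unsupported, mirroring
// the constraint that bias is rejected for 4D BNSH inputs; otherwise reports
// whether a BSNH -> BNSH transpose is still required.
const char* ValidateAndPlan(const AttentionInputs& in, bool& needs_transpose) {
  if (in.format == QkvFormat::Q_K_V_BNSH) {
    if (in.has_bias) {
      return "bias is not supported when Q/K/V are supplied in BNSH (4D) format";
    }
    // Inputs already match the layout the attention kernels expect, so the
    // usual transpose (and its scratch buffer) can be skipped.
    needs_transpose = false;
  } else {
    // 3D BSNH inputs are transposed into BNSH before the kernel runs.
    needs_transpose = true;
  }
  return nullptr;  // supported configuration
}
```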

### Kernel and Operator Documentation Updates

* Updated `OperatorKernels.md` to document the new `Attention` operator
inputs and outputs for both 3D and 4D formats, specifying supported
tensor types for each input.

### Correctness and Consistency Fixes

* Fixed the computation of causal attention indices in CUDA softmax kernels
by correcting the offset calculation used for causal masking (see the
index-rule sketch after this list).
[[1]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL168-R168)
[[2]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL244-R244)
[[3]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL336-R336)
[[4]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL442-R442)
* Updated the workspace allocation logic for QKV preparation so that the
correct workspace size is reserved for the new input formats (a sizing sketch
follows below).
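
For the causal-masking fix in the first bullet, the index rule can be
illustrated with a small self-contained sketch (illustrative code, not the
CUDA kernel itself): with a cached past of length `past_len`, query row `q`
of the new chunk may attend only to key columns up to `past_len + q`.

```cpp
#include <limits>
#include <vector>

// Illustrative only: mask out key positions a query row is not allowed to see.
void ApplyCausalMask(std::vector<float>& scores,  // [seq_len x total_len], row-major
                     int seq_len, int total_len) {
  const int past_len = total_len - seq_len;
  for (int q = 0; q < seq_len; ++q) {
    const int last_visible = past_len + q;  // inclusive upper bound for this row
    for (int k = last_visible + 1; k < total_len; ++k) {
      scores[q * total_len + k] = -std::numeric_limits<float>::infinity();
    }
  }
}
```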

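The workspace change follows the same theme; a hypothetical sketch of the
sizing decision (names and formula are illustrative, not the actual
allocation code): when BNSH inputs are consumed directly, no transpose
scratch is needed.

```cpp
#include <cstddef>

// Hypothetical sizing rule: reserve room for transposed copies of Q, K, and V
// only when a transpose is actually performed.
std::size_t QkvWorkspaceBytes(std::size_t batch, std::size_t num_heads,
                              std::size_t seq_len, std::size_t head_size,
                              bool needs_transpose, std::size_t element_size) {
  if (!needs_transpose) {
    return 0;  // inputs are already in the layout the kernel expects
  }
  return 3 * batch * num_heads * seq_len * head_size * element_size;
}
```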
### Attention Parameter and Helper Refactoring

* Added an `is_output_bnsh` field to `AttentionParameters` to indicate the
output format, and updated the logic that decides output placement and
transposition to use it (see the sketch after this list).
[[1]](diffhunk://#diff-e742290164e1e1fa0152840db2a1b83354e153153df19a2762b58655e49b7f9bR37)
[[2]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
* Refactored the CPU attention implementation to use the new
`attention_helper` namespace for the output mode enums and output shape
computation, improving code clarity and maintainability (a shape-computation
sketch follows below).
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R5)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L118-R125)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L143-R149)
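
A minimal sketch of how a flag like `is_output_bnsh` can drive the
final-transpose decision (hypothetical code, not the actual
`AttentionParameters` definition):

```cpp
// Hypothetical sketch: the flag records which layout the consumer wants, and a
// final transpose is scheduled only when the produced layout differs from it.
struct AttentionParams {
  int batch = 0, num_heads = 0, seq_len = 0, head_size = 0;
  bool is_output_bnsh = false;  // true: output stays in 4D BNSH layout
};

bool NeedsFinalTranspose(const AttentionParams& p, bool kernel_output_is_bnsh) {
  return kernel_output_is_bnsh != p.is_output_bnsh;
}
```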

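And for the `attention_helper` refactor, centralizing the output-shape rule
might look roughly like this (illustrative only; the real helper's signatures
may differ): the 3D output folds the heads into the hidden dimension, while
the 4D output keeps the BNSH layout.

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper only, not the actual attention_helper implementation.
namespace attention_helper {

inline std::vector<int64_t> ComputeOutputShape(int64_t batch, int64_t num_heads,
                                               int64_t seq_len, int64_t head_size,
                                               bool output_bnsh) {
  if (output_bnsh) {
    return {batch, num_heads, seq_len, head_size};  // 4D BNSH output
  }
  return {batch, seq_len, num_heads * head_size};   // 3D BSNH output, heads folded
}

}  // namespace attention_helper
```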
### Minor Cleanups

* Removed outdated asserts and improved debug output strings for QKV
preparation functions to clarify format and state handling.
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L254)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L363)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)

These changes collectively improve the flexibility, correctness, and
maintainability of attention kernel implementations in ONNX Runtime,
especially for advanced transformer models and large language model
workloads.

**NOT supported in this PR**
- Boolean mask
- GQA
- Softcap
- Softmax precision
- `qk_output_mode` values other than -1 and 0