Conversation

@tianleiwu (Contributor) commented Jan 8, 2026

Summary

This PR updates the Flash Attention implementation in ONNX Runtime, syncing the vendored kernels with newer sources from https://github.com/Dao-AILab/flash-attention and extending the internal API to support additional features required for advanced KV-caching scenarios. It also aligns specific kernels with the official implementation.

Changes

  • Flash Attention Kernels: Updated and added Flash Attention forward kernels and headers in onnxruntime/contrib_ops/cuda/bert/flash_attention/.
  • API Extension: Updated mha_fwd and mha_fwd_kvcache in flash_api.h and flash_api.cc to accept two new optional parameters (see the signature sketch after this list):
    • cache_batch_idx: Per-batch indices used to look up rows of the KV cache, allowing non-contiguous batch indices.
    • leftpad_k: Per-batch left-padding of the key sequence.
  • Alignment & Fixes:
    • Cleanup: Removed redundant kInfinity definition in flash_fwd_kernel.h.
    • Includes: Added missing <core/providers/cuda/shared_inc/cuda_call.h> in flash_fwd_launch_template.h.
  • Integration: Updated group_query_attention_impl.cu to align with the new mha_fwd_kvcache signature.
  • Build Configuration: Adjusted onnxruntime_providers_cpu.cmake to update the exclusion list for Flash Attention kernels in quick build mode.
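
A rough sketch of how the extended kvcache entry point might look. This is illustrative only: the real signature in flash_api.h carries many more arguments, and the parameter types and names shown for the two new arguments are assumptions based on the description above.

```cpp
// Sketch of the extended entry point (illustrative; not the full signature).
// Both new parameters are optional: passing nullptr preserves the old behavior.
Status mha_fwd_kvcache(
    /* ...existing q/k/v, KV-cache, stride, and scale arguments... */
    const int* cache_batch_idx,  // optional: batch b reads/writes KV-cache row
                                 // cache_batch_idx[b] instead of row b
    const int* leftpad_k         // optional: per-batch count of left-padding
                                 // tokens to skip at the start of the key sequence
);
```

Existing callers can then pass nullptr for both new arguments, which matches the Integration bullet above: group_query_attention_impl.cu is updated to pass two additional nullptr parameters.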

Implementation Details

  • The run_mha_fwd helper now checks whether cache_batch_idx is provided, in addition to the existing k_new check, to decide whether the split kernel must be forced (see the sketch below).
  • New parameters are propagated through the call stack to the underlying Flash Attention kernels.
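
A minimal sketch of that decision, with variable names assumed (the actual logic lives in flash_api.cc):

```cpp
// Sketch (names assumed): force the split-KV kernel whenever new keys are
// being appended (k_new) or the cache is addressed indirectly through
// cache_batch_idx, since only the split kernel supports appending to the
// KV cache and indexing into it.
const bool force_split_kernel = (k_new != nullptr) || (cache_batch_idx != nullptr);
run_mha_fwd(params, stream, force_split_kernel);
```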

@tianleiwu changed the title from "Update Flash Attention Implementation and APIs" to "[CUDA] Update Flash Attention Implementation and APIs" on Jan 8, 2026

Copilot AI left a comment

Pull request overview

This PR updates the Flash Attention implementation in ONNX Runtime by syncing with the latest kernels from the upstream flash-attention repository and extending the internal API to support advanced caching scenarios with non-contiguous batch indices and left-padding.

Key Changes:

  • Extended mha_fwd and mha_fwd_kvcache APIs with two new optional parameters: cache_batch_idx (for non-contiguous batch indexing) and leftpad_k (for left-padding support)
  • Introduced a namespace configuration system via namespace_config.h for better isolation and flexibility (see the sketch after this list)
  • Added explicit causal template parameter to kernel dispatching functions, creating separate instantiations for causal and non-causal attention patterns
  • Updated numerous kernel files to align with the new template signatures and namespace conventions
  • Standardized copyright headers to 2024
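
A minimal sketch of the namespace configuration pattern. The exact macro set in namespace_config.h may differ; this follows the convention used by the upstream flash-attention sources.

```cpp
// namespace_config.h (sketch): the vendored kernels live in a configurable
// namespace rather than a hard-coded one, so the same sources can be embedded
// by different projects without symbol clashes.
#pragma once

#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE flash  // embedders may override, e.g. to onnxruntime::flash
#endif
```

Kernel sources then wrap their definitions in namespace FLASH_NAMESPACE { ... }, which matches the per-file namespace updates listed in the review summary below.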

Reviewed changes

Copilot reviewed 64 out of 64 changed files in this pull request and generated 1 comment.

Summary per file:

  • flash_api.h / flash_api.cc: Extended API signatures to accept the cache_batch_idx and leftpad_k parameters; updated the force-split-kernel logic.
  • namespace_config.h: New file defining the FLASH_NAMESPACE macro for namespace management.
  • flash.h: Updated function templates to include the Is_causal template parameter; added new fields to the params struct.
  • block_info.h: Added a leftpad_k field and integrated it into offset calculations (see the sketch after this table).
  • flash_fwd_kernel.h: Major updates, including LSE layout handling, leftpad support in rotary embeddings, and return-softmax support.
  • flash_fwd_launch_template.h: Updated kernel dispatch logic with the causal template parameter; modified shared-memory size handling.
  • kernel_traits.h: Changed copy atoms from DefaultCopy to AutoVectorizingCopyWithAssumedAlignment<128> for better performance.
  • utils.h, softmax.h, mask.h, rotary.h: Updated namespace declarations from onnxruntime::flash to FLASH_NAMESPACE.
  • flash_fwd_hdim*_*.cu: All kernel instantiation files updated with new template signatures (added the Is_causal parameter); see the instantiation sketch below.
  • flash_fwd_split_hdim*_*.cu: All split-kernel files updated with new template signatures.
  • update_kernels.py: Python script to generate the kernel files programmatically.
  • group_query_attention_impl.cu: Updated the caller to pass two additional nullptr parameters to mha_fwd_kvcache.
  • onnxruntime_providers_cpu.cmake: Updated build filter comments for quick build mode.
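
To illustrate the block_info.h change, here is a simplified sketch of how a per-batch left pad can shift the key offsets. Member and parameter names are assumed; the real BlockInfo also handles variable-length and cumulative sequence lengths.

```cpp
// Simplified sketch (names assumed): leftpad_k[b] tokens at the start of
// batch b's key sequence are padding and must be skipped when computing
// the base offset into the K (and V) tensors.
struct BlockInfo {
    template <typename Params>
    __device__ BlockInfo(const Params& params, const int bidb)
        : leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) {}

    template <typename index_t>
    __device__ index_t k_offset(const index_t batch_stride,
                                const index_t row_stride,
                                const int bidb) const {
        // Begin this batch's keys leftpad_k rows into its region of the cache.
        return bidb * batch_stride + leftpad_k * row_stride;
    }

    const int leftpad_k;
};
```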

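And a sketch of what one regenerated instantiation file might look like once the explicit Is_causal template parameter and FLASH_NAMESPACE are in place. The file and helper names follow the upstream flash-attention pattern and are assumptions here.

```cpp
// e.g. flash_fwd_hdim128_fp16_causal_sm80.cu (sketch, following the upstream
// pattern): one translation unit per (dtype, head_dim, is_causal) combination
// keeps compile times manageable and lets quick builds exclude most of them.
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template <>
void run_mha_fwd_<cutlass::half_t, 128, /*Is_causal=*/true>(
    Flash_fwd_params& params, cudaStream_t stream) {
    run_mha_fwd_hdim128<cutlass::half_t, /*Is_causal=*/true>(params, stream);
}

}  // namespace FLASH_NAMESPACE
```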

kunal-vaishnavi previously approved these changes Jan 9, 2026
@tianleiwu enabled auto-merge (squash) on January 10, 2026 at 00:36