
Conversation

Contributor

@titaiwangms commented on Oct 31, 2025

Fixes #24554
This pull request adds support for 4D QKV input tensors in the Attention operator, specifically in the unfused CUDA kernel for Attention-23, and refines the handling of sequence lengths and causal masking for correctness and flexibility. The changes touch the operator documentation and several CUDA implementation files, enabling the new input format and ensuring correct output shapes and masking across scenarios.

Support for 4D QKV input and improved QKV format handling:

  • Added support for the Q_K_V_BNSH (4D) QKV input format in the CUDA Attention kernel, including updates to the AttentionParameters and AttentionData structs to track whether the input is 4D. When the input is 4D, the kernel now writes the output directly, avoiding unnecessary transposes.

  • The QKV preparation logic (PrepareQkv_MHA_NoPast, PrepareQkv_MHA_WithPast_NoBias, and PrepareQkv_MultiHeadAttention) now supports both 3D (BSNH) and 4D (BNSH) input formats, with assertions and error handling for unsupported combinations (e.g., bias with 4D input); see the sketch after this list.
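
The following is a minimal sketch of the input-format bookkeeping described above, not the actual ORT kernel code: the format names and the bias restriction come from the PR description, while ClassifyQkvInput and AttentionShapeInfo are hypothetical helpers used only for illustration.

#include <cstddef>
#include <stdexcept>

enum class QkvFormat { Q_K_V_BSNH, Q_K_V_BNSH };

struct AttentionShapeInfo {
  QkvFormat format;
  bool is_4d_input;  // true when Q/K/V arrive as (batch, num_heads, sequence, head_size)
};

inline AttentionShapeInfo ClassifyQkvInput(std::size_t query_rank, bool has_bias) {
  AttentionShapeInfo info{};
  info.is_4d_input = (query_rank == 4);
  info.format = info.is_4d_input ? QkvFormat::Q_K_V_BNSH : QkvFormat::Q_K_V_BSNH;
  if (info.is_4d_input && has_bias) {
    // Mirrors the PR's restriction: bias combined with 4D (BNSH) input is rejected.
    throw std::invalid_argument("bias is not supported with 4D (BNSH) QKV input");
  }
  return info;
}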

Causal masking and sequence length correctness:

  • Refactored the causal masking logic in the softmax CUDA kernels to use the explicit past_sequence_length parameter, ensuring correct attention windowing for both incremental and non-incremental decoding; see the sketch after this list. This affects both the small and large softmax kernels and their raw-mask variants.

  • Updated kernel launch sites to propagate the new past_sequence_length argument, so all relevant CUDA calls use the correct sequence length for masking.
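
As a rough sketch of the windowing rule above (illustrative only; the real logic lives in the softmax CUDA kernels touched by this PR, and the helper below is hypothetical): a query at position i within the current sequence_length tokens sits at absolute position past_sequence_length + i, so under causal masking it may attend only to keys at absolute positions up to that value.

// Hypothetical device helper, not the actual softmax kernel code.
__device__ __forceinline__ bool IsCausallyVisible(int query_pos,            // 0 .. sequence_length - 1
                                                  int key_pos,              // 0 .. total_sequence_length - 1
                                                  int past_sequence_length) {
  // Query i in the current chunk lives at absolute position past_sequence_length + i,
  // so it may attend to keys at absolute positions <= past_sequence_length + i.
  return key_pos <= past_sequence_length + query_pos;
}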

Sequence length variable naming and usage:

  • Renamed and corrected the use of sequence_length and kv_sequence_length in the QKV-to-context computations, ensuring the correct sequence length is used for key/value tensors (see the shape sketch below).
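
As a shape reminder for the renaming above, here is a deliberately naive CPU-side sketch (not ORT code; NaiveScores is a made-up helper): the score matrix is sequence_length x kv_len, so the key/value length, not the query length, must size its inner dimension.

#include <cassert>
#include <cstddef>
#include <vector>

// Naive QK^T for a single head; q is sequence_length x head_size, k is kv_len x head_size.
std::vector<float> NaiveScores(const std::vector<float>& q, const std::vector<float>& k,
                               int sequence_length, int kv_len, int head_size) {
  assert(q.size() == static_cast<std::size_t>(sequence_length) * head_size);
  assert(k.size() == static_cast<std::size_t>(kv_len) * head_size);
  std::vector<float> scores(static_cast<std::size_t>(sequence_length) * kv_len, 0.0f);
  for (int i = 0; i < sequence_length; ++i)
    for (int j = 0; j < kv_len; ++j)
      for (int h = 0; h < head_size; ++h)
        scores[i * kv_len + j] += q[i * head_size + h] * k[j * head_size + h];
  return scores;  // using sequence_length in place of kv_len here breaks every case where they differ
}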

Documentation update:

  • Updated the operator documentation (OperatorKernels.md) to describe the new input/output signatures and supported types for the Attention operator version 23+, reflecting support for bfloat16 and the new 4D input format.

NOT supported in this PR

  • Boolean mask
  • GQA
  • Softcap
  • Softmax precision
  • qk_output_mode other than -1 and 0

@titaiwangms added the ep:CUDA label (issues related to the CUDA execution provider) on Nov 19, 2025
@titaiwangms marked this pull request as ready for review on December 16, 2025 23:47
      stream, total_sequence_length, sequence_length, batch_size, num_heads,
      mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-     data.scratch, scratch2, parameters.is_unidirectional));
+     data.scratch, scratch2, parameters.is_unidirectional, parameters.past_sequence_length));
Contributor

No need to pass this, since it can be deduced from the 2nd and 3rd parameters:
past_sequence_length = total_sequence_length - sequence_length

Contributor Author

I am using parameters.kv_sequence_length here. It turns out that when it's cross attention with causal masking, the offset needs to be calculated as "total_sequence_length - kv_sequence_length". Let me know what you think about the changes.
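
A toy illustration of the distinction (the numbers below are invented for this example, not taken from the PR):

#include <cstdio>

int main() {
  // Self-attention decoding with a KV cache: 9 cached tokens, 1 new query token,
  // 1 new key/value token. Both formulas give the same offset of 9.
  int total_sequence_length = 10, sequence_length = 1, kv_sequence_length = 1;
  std::printf("%d %d\n", total_sequence_length - sequence_length,
              total_sequence_length - kv_sequence_length);  // prints: 9 9

  // Cross-attention-like case: the key/value length is decoupled from the query
  // length, so the formulas diverge and only the kv-based one gives the offset of 0.
  sequence_length = 2;
  kv_sequence_length = 10;
  total_sequence_length = 10;
  std::printf("%d %d\n", total_sequence_length - sequence_length,
              total_sequence_length - kv_sequence_length);  // prints: 8 0
  return 0;
}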

ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output));
}
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
Contributor

This should be moved inside if (!parameters.output_is_Q_K_V_BNSH).

Contributor Author

I think "Attention Output" debugging can be helpful to both cases. data.output by this point should have the result in both cases. Could you elaborate more on the concern?

float scale;
bool use_tf32;
bool is_4d_input = false;
bool output_is_Q_K_V_BNSH; // whether the output format is Q_K_V_BNSH
Contributor

how about is_output_bnsh?

Contributor Author

Done
