Attention(23) CUDA #26466
base: main
Conversation
This reverts commit 4db63bc.
stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
data.scratch, scratch2, parameters.is_unidirectional, parameters.past_sequence_length));
No need to pass this, since you can deduce it from the 2nd and 3rd parameters:
past_sequence_length = total_sequence_length - sequence_length
I am using parameters.kv_sequence_length here. It turns out that when it's cross attention with causal masking, the offset needs to be calculated as "total_sequence_length - kv_sequence_length". Let me know what you think about the changes.
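For reference, a minimal host-side sketch (illustrative only; the function and variable names are not from this PR) of how the two offsets relate, and why they agree for self attention where kv_sequence_length equals sequence_length:

```cpp
#include <cassert>

// Offset suggested from the query side:      total_sequence_length - sequence_length
// Offset used in this change (key/value side): total_sequence_length - kv_sequence_length
// For self attention the two are identical; per the comment above, for cross attention
// with causal masking the key/value form is the one that matches the key timeline.
inline bool IsCausallyVisible(int q_index,               // 0-based index of a new query token
                              int k_index,               // 0-based index into past + new keys
                              int total_sequence_length,
                              int kv_sequence_length) {
  const int offset = total_sequence_length - kv_sequence_length;
  return k_index <= q_index + offset;
}

int main() {
  // Self attention with 3 past tokens and 2 new tokens: total = 5, kv = 2, offset = 3.
  assert(IsCausallyVisible(0, 3, 5, 2));   // query 0 sees keys 0..3
  assert(!IsCausallyVisible(0, 4, 5, 2));  // but not the newest key
  assert(IsCausallyVisible(1, 4, 5, 2));   // query 1 sees keys 0..4
  return 0;
}
```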
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
                                   device_prop.maxThreadsPerBlock, false, temp_output, data.output));
}
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
This should be moved inside if (!parameters.output_is_Q_K_V_BNSH).
I think "Attention Output" debugging can be helpful to both cases. data.output by this point should have the result in both cases. Could you elaborate more on the concern?
float scale;
bool use_tf32;
bool is_4d_input = false;
bool output_is_Q_K_V_BNSH;  // whether the output format is Q_K_V_BNSH
how about is_output_bnsh?
Done
Fixes #24554
This pull request introduces support for 4D QKV input tensors in the Attention operator, specifically in the unfused CUDA kernel for Attention-23, and refines the handling of sequence lengths and causal masking for improved correctness and flexibility. The changes span the operator documentation and several CUDA implementation files, enabling new input formats and ensuring correct output shapes and masking in various scenarios.
Support for 4D QKV input and improved QKV format handling:
- Added support for the Q_K_V_BNSH (4D) QKV input format in the CUDA Attention kernel, including updates to the AttentionParameters and AttentionData structs to track whether the input is 4D. The kernel now writes the output directly when the input is 4D, avoiding unnecessary transposes (see the layout sketch after this list).
- The QKV preparation logic (PrepareQkv_MHA_NoPast, PrepareQkv_MHA_WithPast_NoBias, and PrepareQkv_MultiHeadAttention) now supports both 3D (BSNH) and 4D (BNSH) input formats, with appropriate assertions and error handling for unsupported scenarios (e.g., bias with 4D input).
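To make the layout difference concrete, here is a small illustrative sketch (hypothetical helper names, not code from this PR) of BSNH versus BNSH element addressing, which is why a Q_K_V_BNSH (4D) input can be consumed and written without a transpose:

```cpp
#include <cstddef>

// B = batch_size, S = sequence_length, N = num_heads, H = head_size.
// BSNH: [batch, sequence, num_heads, head_size] -- the usual 3D QKV layout after reshape.
// BNSH: [batch, num_heads, sequence, head_size] -- the Q_K_V_BNSH (4D) layout.
inline size_t OffsetBSNH(size_t b, size_t s, size_t n, size_t h,
                         size_t S, size_t N, size_t H) {
  return ((b * S + s) * N + n) * H + h;
}

inline size_t OffsetBNSH(size_t b, size_t n, size_t s, size_t h,
                         size_t S, size_t N, size_t H) {
  return ((b * N + n) * S + s) * H + h;
}

int main() {
  // The same logical element (b=0, s=1, n=0, h=3) with S=2, N=2, H=4 lives at
  // different linear offsets in the two layouts, so a kernel written for BNSH
  // addressing can use a 4D BNSH tensor in place instead of transposing BSNH first.
  return (OffsetBSNH(0, 1, 0, 3, 2, 2, 4) == 11 &&
          OffsetBNSH(0, 0, 1, 3, 2, 2, 4) == 7) ? 0 : 1;
}
```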
Causal masking and sequence length correctness:
- Refactored the causal masking logic in the softmax CUDA kernels to use the explicit past_sequence_length parameter, ensuring correct attention windowing for both incremental and non-incremental decoding. This affects both the small and large softmax kernels and their raw-mask variants (a sketch of the window check follows this list).
- Updated the kernel launch sites to propagate the new past_sequence_length argument, ensuring all relevant CUDA calls use the correct sequence length for masking.
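A rough illustration of the windowing described above (hypothetical names; the actual softmax kernels apply this check per thread on the GPU):

```cpp
#include <cassert>

// Illustrative host-side version of the causal window applied in the softmax step.
// q_index: 0..sequence_length-1, position within the newly appended query tokens.
// k_index: 0..total_sequence_length-1, position within the concatenated past + new keys.
inline bool InCausalWindow(int q_index, int k_index, int past_sequence_length) {
  // The query's absolute position is past_sequence_length + q_index; it may only
  // attend to keys at or before that position.
  return k_index <= past_sequence_length + q_index;
}

int main() {
  // Non-incremental (prompt) decoding: no past, standard lower-triangular mask.
  assert(InCausalWindow(/*q*/1, /*k*/1, /*past*/0));
  assert(!InCausalWindow(/*q*/1, /*k*/2, /*past*/0));
  // Incremental decoding: one new token with 4 past tokens sees the whole history.
  assert(InCausalWindow(/*q*/0, /*k*/4, /*past*/4));
  return 0;
}
```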
Sequence length variable naming and usage:
- Clarified the use of sequence_length and kv_sequence_length in the context of QKV-to-context computations, ensuring the correct sequence length is used for key/value tensors.

Documentation update:
- Updated the operator kernel documentation (OperatorKernels.md) to describe the new input/output signatures and supported types for the Attention operator version 23+, reflecting support for bfloat16 and the new 4D input format.

NOT supported in this PR