-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Support boolean attention mask in Attention(23) CUDA #27129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
titaiwangms
wants to merge
5
commits into
main
Choose a base branch
from
titaiwang/support_bool_mask_in_attn
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
2992be7
support bool mask
titaiwangms c7b5511
lint
titaiwangms e8e7bb6
Merge branch 'main' into titaiwang/support_bool_mask_in_attn
titaiwangms 8f71398
address comments
titaiwangms 8943c5b
Merge branch 'main' into titaiwang/support_bool_mask_in_attn
titaiwangms File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
onnxruntime/core/providers/cuda/llm/attention_mask_convert.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/cuda/cu_inc/common.cuh" | ||
| #include "core/providers/cuda/shared_inc/cuda_call.h" | ||
| #include "core/providers/cuda/llm/attention_mask_convert.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| template <typename T> | ||
| __global__ void ConvertBoolMaskToFloatBiasKernel( | ||
| T* output, | ||
| const bool* input, | ||
| CUDA_LONG size, | ||
| T mask_filter_value) { | ||
| CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
| if (idx < size) { | ||
| // Boolean mask semantics: true = attend (bias 0.0), false = mask out (bias mask_filter_value) | ||
| output[idx] = input[idx] ? static_cast<T>(0.0f) : mask_filter_value; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| Status LaunchConvertBoolMaskToFloatBias( | ||
| cudaStream_t stream, | ||
| T* output, | ||
| const bool* input, | ||
| int64_t size, | ||
| T mask_filter_value) { | ||
| if (size <= 0) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // Use CeilDiv to safely compute grid size and avoid integer overflow | ||
| int grid_size = static_cast<int>(CeilDiv(size, static_cast<int64_t>(GridDim::maxThreadsPerBlock))); | ||
|
|
||
| ConvertBoolMaskToFloatBiasKernel<T><<<grid_size, GridDim::maxThreadsPerBlock, 0, stream>>>( | ||
| output, input, static_cast<CUDA_LONG>(size), mask_filter_value); | ||
|
|
||
| return CUDA_CALL(cudaGetLastError()); | ||
| } | ||
|
|
||
| // Explicit template instantiations | ||
| template Status LaunchConvertBoolMaskToFloatBias<float>( | ||
| cudaStream_t stream, float* output, const bool* input, int64_t size, float mask_filter_value); | ||
|
|
||
| template Status LaunchConvertBoolMaskToFloatBias<half>( | ||
| cudaStream_t stream, half* output, const bool* input, int64_t size, half mask_filter_value); | ||
|
|
||
| template Status LaunchConvertBoolMaskToFloatBias<BFloat16>( | ||
| cudaStream_t stream, BFloat16* output, const bool* input, int64_t size, BFloat16 mask_filter_value); | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
23 changes: 23 additions & 0 deletions
23
onnxruntime/core/providers/cuda/llm/attention_mask_convert.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/common/common.h" | ||
| #include <cuda_runtime.h> | ||
|
Check warning on line 7 in onnxruntime/core/providers/cuda/llm/attention_mask_convert.h
|
||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| // Convert boolean attention mask to float attention bias. | ||
| // Boolean semantics: true = attend (output 0.0), false = mask out (output mask_filter_value) | ||
| template <typename T> | ||
| Status LaunchConvertBoolMaskToFloatBias( | ||
| cudaStream_t stream, | ||
| T* output, // [size] output buffer (float type) | ||
| const bool* input, // [size] input buffer (bool type) | ||
| int64_t size, // total number of elements | ||
| T mask_filter_value); // value to use for masked positions (typically -10000.0f) | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -448,6 +448,7 @@ TEST(AttentionTest, Attention4DDefault) { | |
| ); | ||
| } | ||
|
|
||
| // Edge case where attention mask blocks all tokens (all false) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xadupre I can't pass this. IIUC, this means attn_mask is blocking every token? |
||
| TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { | ||
| int batch_size = 2; // Q.shape[0] | ||
| int q_num_heads = 3; // Q.shape[1] | ||
|
|
@@ -617,7 +618,7 @@ TEST(AttentionTest, Attention4DAttnMaskBool) { | |
| q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(), | ||
| -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type | ||
| y, std::vector<float>(), std::vector<float>(), std::vector<float>(), | ||
| false, true, true // disable_cpu, disable_cuda, disable_dml | ||
| false, false, true // disable_cpu, disable_cuda, disable_dml | ||
| ); | ||
| } | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
attn_mask->Data<T>()will raise exception if attn_mask is bool type. Did you add test case for this?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added comment and exception.