Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions onnxruntime/core/providers/cuda/llm/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/attention.h"
#include "core/providers/cpu/llm/attention_helper.h"
#include "core/providers/cuda/llm/attention.h"
#include "core/providers/cuda/llm/attention_mask_convert.h"
#include "contrib_ops/cuda/bert/attention_data.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

Expand Down Expand Up @@ -127,10 +129,25 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Note: The new Attention op uses attn_mask as attention_bias
// The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
// attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).

// Handle boolean mask conversion
IAllocatorUniquePtr<T> converted_bool_mask_buffer;

if (attn_mask != nullptr) {
// TODO(titaiwang, xadupre): attn_mask bool is not supported yet
// Convert boolean mask to float attention bias if needed
if (attn_mask->IsDataType<bool>()) {
ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
// Allocate space for converted mask
size_t mask_size = SafeInt<size_t>(attn_mask->Shape().Size());
converted_bool_mask_buffer = GetScratchBuffer<T>(mask_size, context->GetComputeStream());

// Launch CUDA kernel to convert: true->0.0f, false->mask_filter_value
typedef typename ToCudaType<T>::MappedType CudaT;
ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToFloatBias<CudaT>(
Stream(context),
reinterpret_cast<CudaT*>(converted_bool_mask_buffer.get()),
attn_mask->Data<bool>(),
static_cast<int64_t>(mask_size),
ToCudaType<T>::FromFloat(mask_filter_value<float>())));
}

size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
Expand All @@ -157,7 +174,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
contribop_parameters.broadcast_attn_bias_dim_1 = false;
}

contribop_parameters.mask_filter_value = -10000.0f;
contribop_parameters.mask_filter_value = mask_filter_value<float>();
contribop_parameters.scale = parameters.scale;
contribop_parameters.use_tf32 = UseTF32();

Expand Down Expand Up @@ -215,7 +232,17 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Set additional fields
data.bias = nullptr; // New Attention op doesn't have bias
if (nullptr != attn_mask) {
data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
// Use converted buffer if boolean mask was converted, otherwise use original mask.
// The else branch is safe because: if attn_mask is bool, we convert it above and
// converted_bool_mask_buffer becomes non-null. So this else branch is only reached
// when attn_mask is of type T (per TypeConstraint U = bool | T).
if (converted_bool_mask_buffer) {
data.attention_bias = reinterpret_cast<const CudaT*>(converted_bool_mask_buffer.get());
} else {
ORT_ENFORCE(!attn_mask->IsDataType<bool>(),
"Boolean mask should have been converted to float bias above.");
data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
Copy link
Contributor

@tianleiwu tianleiwu Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attn_mask->Data<T>() will raise exception if attn_mask is bool type. Did you add test case for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment and exception.

}
}
data.qkv_format = contribop_parameters.qkv_format;

Expand Down
55 changes: 55 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_convert.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/llm/attention_mask_convert.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void ConvertBoolMaskToFloatBiasKernel(
T* output,
const bool* input,
CUDA_LONG size,
T mask_filter_value) {
CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx < size) {
// Boolean mask semantics: true = attend (bias 0.0), false = mask out (bias mask_filter_value)
output[idx] = input[idx] ? static_cast<T>(0.0f) : mask_filter_value;
}
}

template <typename T>
Status LaunchConvertBoolMaskToFloatBias(
cudaStream_t stream,
T* output,
const bool* input,
int64_t size,
T mask_filter_value) {
if (size <= 0) {
return Status::OK();
}

// Use CeilDiv to safely compute grid size and avoid integer overflow
int grid_size = static_cast<int>(CeilDiv(size, static_cast<int64_t>(GridDim::maxThreadsPerBlock)));

ConvertBoolMaskToFloatBiasKernel<T><<<grid_size, GridDim::maxThreadsPerBlock, 0, stream>>>(
output, input, static_cast<CUDA_LONG>(size), mask_filter_value);

return CUDA_CALL(cudaGetLastError());
}

// Explicit template instantiations
template Status LaunchConvertBoolMaskToFloatBias<float>(
cudaStream_t stream, float* output, const bool* input, int64_t size, float mask_filter_value);

template Status LaunchConvertBoolMaskToFloatBias<half>(
cudaStream_t stream, half* output, const bool* input, int64_t size, half mask_filter_value);

template Status LaunchConvertBoolMaskToFloatBias<BFloat16>(
cudaStream_t stream, BFloat16* output, const bool* input, int64_t size, BFloat16 mask_filter_value);

} // namespace cuda
} // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_convert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include <cuda_runtime.h>

Check warning on line 7 in onnxruntime/core/providers/cuda/llm/attention_mask_convert.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: attention_mask_convert.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/cuda/llm/attention_mask_convert.h:7: Found C system header after other header. Should be: attention_mask_convert.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime {
namespace cuda {

// Convert boolean attention mask to float attention bias.
// Boolean semantics: true = attend (output 0.0), false = mask out (output mask_filter_value)
template <typename T>
Status LaunchConvertBoolMaskToFloatBias(
cudaStream_t stream,
T* output, // [size] output buffer (float type)
const bool* input, // [size] input buffer (bool type)
int64_t size, // total number of elements
T mask_filter_value); // value to use for masked positions (typically -10000.0f)

} // namespace cuda
} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/cpu/llm/attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ TEST(AttentionTest, Attention4DDefault) {
);
}

// Edge case where attention mask blocks all tokens (all false)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xadupre I can't pass this. IIUC, this means attn_mask is blocking every token?

TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) {
int batch_size = 2; // Q.shape[0]
int q_num_heads = 3; // Q.shape[1]
Expand Down Expand Up @@ -617,7 +618,7 @@ TEST(AttentionTest, Attention4DAttnMaskBool) {
q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
false, true, true // disable_cpu, disable_cuda, disable_dml
false, false, true // disable_cpu, disable_cuda, disable_dml
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
"^test_attention_4d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda
"^test_attention_4d_softcap_cuda", // softcap not supported in Attention-cuda
"^test_attention_4d_with_qk_matmul_softcap_cuda", // softcap not supported in Attention-cuda
"^test_attention_4d_attn_mask_bool_cuda", // bool mask not supported in Attention-cuda
"^test_attention_4d_attn_mask_bool_4d_cuda", // bool mask not supported in Attention-cuda
"^test_attention_3d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda
"^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cuda", // QK matmul + bias not supported in Attention-cuda
"^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cuda", // QK matmul + bias not supported in Attention-cuda
Expand Down
Loading