diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 7b37a3a4227b6..852f0bcaff5a2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -258,7 +258,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
   assert(data.past_value == nullptr);
   assert(data.present_key == nullptr);
   assert(data.present_value == nullptr);
-  assert(!parameters.is_unidirectional);
+  // Note: is_unidirectional (causal) is supported by flash attention, memory efficient attention,
+  // cuDNN flash attention, and the unfused kernel. The TRT fused runner is only used when
+  // !is_unidirectional (enforced in MultiHeadAttention::ComputeInternal).
   assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));
 
   if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index 99f297bba6444..3b7aebc2d4714 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -191,7 +191,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
   }
 
-  // TODO(titaiwang): Continue on these parameters
   // Construct AttentionData to pass to QkvToContext
   typedef typename ToCudaType<T>::MappedType CudaT;
   onnxruntime::contrib::cuda::AttentionData<CudaT> data;
@@ -220,12 +219,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   }
   data.qkv_format = contribop_parameters.qkv_format;
 
-  // TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
   // For now, set flags to false and let QkvToContext use the unfused path
   data.use_flash_attention = false;
   data.use_memory_efficient_attention = false;
   data.fused_runner = nullptr;
   data.fused_cross_attention_kernel = nullptr;
+  data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused;
 
   // Allocate workspace for Q, K, V processing and scratch buffer
   const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
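
A minimal sketch of the causal-support rule described by the new note in attention_prepare_qkv.cu. The enum, helper, and selection logic below are illustrative only and are not the actual ORT dispatch code; the real check lives in MultiHeadAttention::ComputeInternal.

    // Sketch only: which kernels can honor is_unidirectional (causal attention).
    // All names here are hypothetical stand-ins for the real ORT kernel choices.
    enum class Kernel { TrtFusedRunner, FlashAttention, MemoryEfficient, CudnnFlash, Unfused };

    bool SupportsCausal(Kernel k) {
      // Per the note: every kernel except the TRT fused runner handles a causal mask.
      return k != Kernel::TrtFusedRunner;
    }

    Kernel Select(bool is_unidirectional, bool trt_fused_available) {
      // The TRT fused runner is only eligible for bidirectional attention,
      // which is why PrepareQkv_MHA_NoPast no longer asserts !is_unidirectional.
      if (trt_fused_available && !is_unidirectional) return Kernel::TrtFusedRunner;
      return Kernel::Unfused;  // simplified fallback; real code also tries flash, etc.
    }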