onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu (4 changes: 3 additions & 1 deletion)
@@ -258,7 +258,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
 assert(data.past_value == nullptr);
 assert(data.present_key == nullptr);
 assert(data.present_value == nullptr);
-assert(!parameters.is_unidirectional);
+// Note: is_unidirectional (causal) is supported by flash attention, memory efficient attention,
+// cuDNN flash attention, and the unfused kernel. The TRT fused runner is only used when
+// !is_unidirectional (enforced in MultiHeadAttention::ComputeInternal).
 assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));

 if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
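To make the new comment concrete, here is a sketch of the kernel-selection rule it describes. This is not the real onnxruntime code: the include path, the Has* predicates, and the priority order are hypothetical stand-ins for the capability checks in MultiHeadAttention::ComputeInternal, and all enum values except AttentionKernel_Unfused (which appears in this diff) are assumptions.

// Sketch only: restates the dispatch rule from the comment above.
// The Has* predicates are hypothetical stand-ins for the capability checks in
// MultiHeadAttention::ComputeInternal, and the priority order is illustrative.
#include "contrib_ops/cpu/bert/attention_common.h"  // assumed location of AttentionKernelType

using onnxruntime::contrib::AttentionKernelType;
using onnxruntime::contrib::AttentionParameters;

bool HasFlashAttention(const AttentionParameters&);            // hypothetical
bool HasMemoryEfficientAttention(const AttentionParameters&);  // hypothetical
bool HasTrtFusedRunner(const AttentionParameters&);            // hypothetical

AttentionKernelType ChooseKernel(const AttentionParameters& p) {
  if (HasFlashAttention(p)) {
    return AttentionKernelType::AttentionKernel_FlashAttention;  // causal OK
  }
  if (HasMemoryEfficientAttention(p)) {
    return AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;  // causal OK
  }
  if (!p.is_unidirectional && HasTrtFusedRunner(p)) {
    // The TRT fused runner is the only option gated on !is_unidirectional.
    return AttentionKernelType::AttentionKernel_TrtFusedAttention;
  }
  return AttentionKernelType::AttentionKernel_Unfused;  // fallback, causal OK
}

The point is only the gating: the TRT fused runner is the one branch that must see !is_unidirectional, so the blanket assert can go; every other kernel, including the unfused fallback, handles the causal mask itself.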
onnxruntime/core/providers/cuda/llm/attention.cc (3 changes: 1 addition & 2 deletions)
@@ -191,7 +191,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
 }

-// TODO(titaiwang): Continue on these parameters
 // Construct AttentionData to pass to QkvToContext
 typedef typename ToCudaType<T>::MappedType CudaT;
 onnxruntime::contrib::cuda::AttentionData<CudaT> data;
@@ -220,12 +219,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 }
 data.qkv_format = contribop_parameters.qkv_format;

-// TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
+// For now, set flags to false and let QkvToContext use the unfused path
 data.use_flash_attention = false;
 data.use_memory_efficient_attention = false;
 data.fused_runner = nullptr;
 data.fused_cross_attention_kernel = nullptr;
 data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused;

 // Allocate workspace for Q, K, V processing and scratch buffer
 const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
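With every fused-kernel flag cleared as above, QkvToContext (in contrib_ops/cuda/bert/attention_impl.cu) has nothing left to take but the unfused path. Below is a minimal sketch of that fallthrough; DispatchAttention, its signature, and the Launch* helpers are hypothetical, while the AttentionData fields are the ones set in this diff.

// Sketch only: a simplified view of the dispatch that the flags above feed.
// DispatchAttention and the Launch* helpers are hypothetical; the real
// QkvToContext branches on these same fields but with far more inputs.
template <typename T>
onnxruntime::common::Status DispatchAttention(
    const onnxruntime::contrib::cuda::AttentionData<T>& data) {
  if (data.use_flash_attention) {
    return LaunchFlashAttention(data);            // not taken: set to false above
  }
  if (data.use_memory_efficient_attention) {
    return LaunchMemoryEfficientAttention(data);  // not taken: set to false above
  }
  if (data.fused_runner != nullptr || data.fused_cross_attention_kernel != nullptr) {
    return LaunchTrtFusedAttention(data);         // not taken: both are nullptr above
  }
  // All fused options disabled, matching kernel_type == AttentionKernel_Unfused:
  // fall back to the generic GEMM -> softmax -> GEMM implementation.
  return LaunchUnfusedAttention(data);
}

Once real kernel selection is wired up here, flipping these flags (as the contrib MultiHeadAttention op does) is enough to re-enable the fused paths; the unfused GEMM + softmax + GEMM path remains the universal fallback.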