diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc
index 5849ce6cffdfe..afd5d55664160 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.cc
+++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -204,8 +204,8 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
   // However, if present_key is not requested, we should avoid allocating more memory than needed, but that means
   // allocating one buffer per thread. That's why this is not implemented.
   // The user should define a model with a present_key, even if unused, whenever past_key is not null.
-  ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr),
-              "The implementation only supports past_key and present_key both null or both not null.");
+  ORT_ENFORCE(!((past_key != nullptr) && (present_key == nullptr)),
+              "The implementation does not support past_key provided and present_key being null.");
   const size_t past_chunk_length = static_cast<size_t>(parameters.past_sequence_length) * parameters.head_size;  // P x H
   const size_t q_input_chunk_length = static_cast<size_t>(parameters.q_sequence_length) * parameters.head_size;  // S x H
   const size_t k_input_chunk_length = static_cast<size_t>(parameters.kv_sequence_length) * parameters.head_size;  // L x H
@@ -529,8 +529,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output,  // bu
                                             T* present_value,       // present value only (if not using present state)
                                             bool transpose_output,  // whether to transpose the output (0, 2, 1, 3)
                                             ThreadPool* tp) const {
-  ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr),
-              "The implementation only supports past_value and present_value both null or both not null.");
+  ORT_ENFORCE(!((past_value != nullptr) && (present_value == nullptr)),
+              "The implementation does not support past_value provided and present_value being null.");
   const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size;   // P x H_v
   const ptrdiff_t v_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size;  // L x H_v
   const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length;              // T x H_v
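Side note on the relaxed check: the old predicate (past == nullptr) == (present == nullptr) and the new one !((past != nullptr) && (present == nullptr)) differ in exactly one of the four null/non-null combinations. Requesting a present_key output with no past_key input was previously rejected and is now allowed; only the unsupported case (past_key given, present_key absent) still fails, matching the buffer-allocation rationale in the comment above. A minimal standalone sketch (plain C++ with dummy pointers, not the actual kernel arguments) that enumerates the cases:

#include <cstdio>

int main() {
  // Enumerate all four null/non-null combinations of (past, present).
  int dummy = 0;
  const int* values[] = {nullptr, &dummy};
  for (const int* past : values) {
    for (const int* present : values) {
      bool old_check = (past == nullptr) == (present == nullptr);     // both null or both set
      bool new_check = !((past != nullptr) && (present == nullptr));  // only reject past without present
      std::printf("past=%d present=%d old=%d new=%d\n",
                  past != nullptr, present != nullptr, old_check, new_check);
    }
  }
  // Only (past=0, present=1) changes: previously rejected, now allowed,
  // so a graph can request a present_* output without feeding past_*.
  return 0;
}

The same reasoning applies verbatim to the past_value/present_value check in ComputeVxAttentionScore below.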
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
index 5b970aa12f1d1..894130d6ee991 100644
--- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -1267,5 +1267,116 @@ TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) {
   );
 }
 
+// Test for attention when past_key and past_value are nullptr.
+// This test verifies that the updated logic handles the case where no past state is provided.
+TEST(AttentionTest, AttentionNoPastKeyValue) {
+  int batch_size = 1;            // Q.shape[0]
+  int q_num_heads = 2;           // Q.shape[2] = q_num_heads * head_size (3D layout)
+  int q_sequence_length = 2;     // Q.shape[1]
+  int head_size = 3;
+  int kv_sequence_length = 2;    // K.shape[1] and V.shape[1]
+  int kv_num_heads = 2;          // K.shape[2] = kv_num_heads * head_size
+  int v_head_size = 3;
+  int past_sequence_length = 0;  // No past state
+
+  std::vector<float> q = {
+      0.548814f, 0.715189f, 0.602763f,
+      0.423655f, 0.645894f, 0.437587f,
+      0.963663f, 0.383442f, 0.791725f,
+      0.568045f, 0.925597f, 0.071036f};
+
+  std::vector<float> k = {
+      0.186193f, 0.944372f, 0.739551f,
+      0.227415f, 0.254356f, 0.058029f,
+      0.311796f, 0.696343f, 0.377752f,
+      0.024679f, 0.067250f, 0.679393f};
+
+  std::vector<float> v = {
+      0.070870f, 0.292794f, 0.152355f,
+      0.131289f, 0.604118f, 0.382808f,
+      0.967795f, 0.546885f, 0.274824f,
+      0.896761f, 0.406733f, 0.552078f};
+
+  ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size);
+  ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size);
+  ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size);
+
+  std::vector<float> expected_y = {
+      0.477184f, 0.407899f, 0.207834f,
+      0.521223f, 0.503569f, 0.469034f,
+      0.485670f, 0.410303f, 0.208993f,
+      0.487087f, 0.512371f, 0.461486f};
+
+  // Test with no past_key or past_value (empty vectors).
+  // The empty vectors cause AddOptionalInputEdge to be called, resulting in nullptr tensors.
+  RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
+            q, k, v, std::vector<float>(), std::initializer_list<float>(), std::vector<float>(), std::vector<float>(),
+            -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,
+            expected_y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
+            false, true, true  // disable_cpu, disable_cuda, disable_dml
+  );
+}
+
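As a sanity check on expected_y: with no past state and no mask, the kernel reduces to plain scaled dot-product attention per head, softmax(Q K^T / sqrt(head_size)) V, with the 3D inputs laid out as (batch, sequence, num_heads * head_size). Below is a minimal standalone sketch (Attention3D is a hypothetical helper, not the ORT test harness; it assumes head_size == v_head_size as in this test) that reproduces expected_y from the q/k/v data above to within float rounding:

#include <cmath>
#include <cstdio>
#include <vector>

// Scaled dot-product attention for one batch with 3D (S, N*H) packed inputs.
// q: (S, N*H), k/v: (L, N*H); output y: (S, N*H). No mask, no past state.
std::vector<float> Attention3D(const std::vector<float>& q, const std::vector<float>& k,
                               const std::vector<float>& v, int S, int L, int N, int H) {
  std::vector<float> y(q.size(), 0.f);
  const float scale = 1.f / std::sqrt(static_cast<float>(H));
  for (int n = 0; n < N; ++n) {
    for (int s = 0; s < S; ++s) {
      // Scaled scores for query position s against all key positions.
      std::vector<float> p(L);
      float sum = 0.f;
      for (int l = 0; l < L; ++l) {
        float dot = 0.f;
        for (int h = 0; h < H; ++h)
          dot += q[(s * N + n) * H + h] * k[(l * N + n) * H + h];
        p[l] = std::exp(dot * scale);
        sum += p[l];
      }
      // Softmax-weighted sum of values.
      for (int l = 0; l < L; ++l)
        for (int h = 0; h < H; ++h)
          y[(s * N + n) * H + h] += (p[l] / sum) * v[(l * N + n) * H + h];
    }
  }
  return y;
}

int main() {
  // Data from AttentionNoPastKeyValue above: B=1, N=2, S=L=2, H=3.
  std::vector<float> q = {0.548814f, 0.715189f, 0.602763f, 0.423655f, 0.645894f, 0.437587f,
                          0.963663f, 0.383442f, 0.791725f, 0.568045f, 0.925597f, 0.071036f};
  std::vector<float> k = {0.186193f, 0.944372f, 0.739551f, 0.227415f, 0.254356f, 0.058029f,
                          0.311796f, 0.696343f, 0.377752f, 0.024679f, 0.067250f, 0.679393f};
  std::vector<float> v = {0.070870f, 0.292794f, 0.152355f, 0.131289f, 0.604118f, 0.382808f,
                          0.967795f, 0.546885f, 0.274824f, 0.896761f, 0.406733f, 0.552078f};
  std::vector<float> y = Attention3D(q, k, v, /*S=*/2, /*L=*/2, /*N=*/2, /*H=*/3);
  for (float e : y) std::printf("%f ", e);  // matches expected_y to about 1e-5
  std::printf("\n");
  return 0;
}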
+// Test for attention with present_key/present_value outputs but no past state.
+// This ensures the ConcatStateChunk logic works correctly when past_key/past_value are nullptr.
+TEST(AttentionTest, AttentionNoPastWithPresentOutput) {
+  int batch_size = 1;            // Q.shape[0]
+  int q_num_heads = 2;           // Q.shape[2] = q_num_heads * head_size (3D layout)
+  int q_sequence_length = 2;     // Q.shape[1]
+  int head_size = 2;
+  int kv_sequence_length = 2;    // K.shape[1] and V.shape[1]
+  int kv_num_heads = 2;          // K.shape[2] = kv_num_heads * head_size
+  int v_head_size = 2;
+  int past_sequence_length = 0;  // No past state
+
+  std::vector<float> q = {
+      0.5f, 0.8f,
+      0.3f, 0.9f,
+      0.7f, 0.4f,
+      0.6f, 0.2f};
+
+  std::vector<float> k = {
+      0.1f, 0.7f,
+      0.4f, 0.6f,
+      0.8f, 0.3f,
+      0.2f, 0.9f};
+
+  std::vector<float> v = {
+      1.0f, 2.0f,
+      3.0f, 4.0f,
+      0.5f, 1.5f,
+      2.5f, 3.5f};
+
+  // The present key/value outputs have a 4D (B, N, S, H) shape, so the data is
+  // transposed relative to the 3D (B, S, N*H) inputs.
+  std::vector<float> expected_present_key = {
+      0.1f, 0.7f,
+      0.8f, 0.3f,
+      0.4f, 0.6f,
+      0.2f, 0.9f};
+
+  std::vector<float> expected_present_value = {
+      1.0f, 2.0f,
+      0.5f, 1.5f,
+      3.0f, 4.0f,
+      2.5f, 3.5f};
+
+  ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size);
+  ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size);
+  ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size);
+
+  std::vector<float> expected_y = {
+      0.747348f, 1.747348f,
+      2.731472f, 3.731472f,
+      0.720963f, 1.720963f,
+      2.755302f, 3.755302f};
+
+  RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
+            q, k, v, std::vector<float>(), std::initializer_list<float>(), std::vector<float>(), std::vector<float>(),
+            -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,
+            expected_y, expected_present_key, expected_present_value, std::vector<float>(),
+            false, true, true  // disable_cpu, disable_cuda, disable_dml
+  );
+}
+
 }  // namespace test
 }  // namespace onnxruntime
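For the second test, the expected present_key/present_value values are just the 3D (B, S, N*H) inputs rearranged into the 4D (B, N, S, H) present layout, with no concatenation since there is no past state. A minimal sketch of that rearrangement (ToPresentLayout is a hypothetical helper name, not an ORT function) that maps k above onto expected_present_key:

#include <cstdio>
#include <vector>

// Rearrange a (B, S, N*H) packed input into the (B, N, S, H) present layout.
std::vector<float> ToPresentLayout(const std::vector<float>& x, int B, int S, int N, int H) {
  std::vector<float> out(x.size());
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < S; ++s)
        for (int h = 0; h < H; ++h)
          out[((b * N + n) * S + s) * H + h] = x[(b * S + s) * N * H + n * H + h];
  return out;
}

int main() {
  // k from AttentionNoPastWithPresentOutput: B=1, S=2, N=2, H=2.
  std::vector<float> k = {0.1f, 0.7f, 0.4f, 0.6f, 0.8f, 0.3f, 0.2f, 0.9f};
  for (float e : ToPresentLayout(k, 1, 2, 2, 2)) std::printf("%.1f ", e);
  std::printf("\n");  // prints 0.1 0.7 0.8 0.3 0.4 0.6 0.2 0.9 == expected_present_key
  return 0;
}

Applying the same rearrangement to v yields expected_present_value, which is how the expected tensors above were derived.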