8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/llm/attention.cc
@@ -204,8 +204,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
// However, if present_key is not requested, we should avoid allocating more memory than needed, but that
// would mean allocating one buffer per thread; that is why this path is not implemented.
// The user should define a model with a present_key output, even if unused, whenever past_key is not null.
- ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr),
-             "The implementation only supports past_key and present_key both null or both not null.");
+ ORT_ENFORCE(!((past_key != nullptr) && (present_key == nullptr)),
+             "The implementation does not support past_key provided and present_key being null.");
const size_t past_chunk_length = static_cast<size_t>(parameters.past_sequence_length) * parameters.head_size; // P x H
const size_t q_input_chunk_length = static_cast<size_t>(parameters.q_sequence_length) * parameters.head_size; // S x H
const size_t k_input_chunk_length = static_cast<size_t>(parameters.kv_sequence_length) * parameters.head_size; // L x H
@@ -529,8 +529,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
T* present_value, // present value only (if not using present state)
bool transpose_output, // whether to transpose the output (0, 2, 1, 3)
ThreadPool* tp) const {
- ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr),
-             "The implementation only supports past_value and present_value both null or both not null.");
+ ORT_ENFORCE(!((past_value != nullptr) && (present_value == nullptr)),
+             "The implementation does not support past_value provided and present_value being null.");
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
const ptrdiff_t v_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v
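The effect of the relaxed check is easiest to see as a truth table over the four nullability combinations. Below is a minimal standalone sketch (the `old_check`/`new_check` names are illustrative, not from the PR); the same table applies to the past_value/present_value pair:

```cpp
#include <cstdio>

// Truth table for the old and new ORT_ENFORCE predicates.
// `past`/`present` stand for past_key/present_key being provided (non-null).
int main() {
  const bool options[2] = {false, true};
  for (bool past : options) {
    for (bool present : options) {
      const bool old_check = (past == present);    // both null or both non-null
      const bool new_check = !(past && !present);  // reject only: past set, present missing
      std::printf("past=%d present=%d old=%s new=%s\n", past, present,
                  old_check ? "ok" : "reject", new_check ? "ok" : "reject");
    }
  }
  return 0;
}
```

The two predicates differ only for past=0, present=1: a graph may now request present_key/present_value without feeding any past state, which is exactly what the new tests below exercise.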
111 changes: 111 additions & 0 deletions onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -1267,5 +1267,116 @@ TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) {
);
}

// Test for attention when past_key and past_value are nullptr.
// This test verifies that the updated check accepts the case where no past state is provided.
TEST(AttentionTest, AttentionNoPastKeyValue) {
int batch_size = 1;            // Q.shape[0]
int q_num_heads = 2;           // number of heads packed into Q.shape[2]
int q_sequence_length = 2;     // Q.shape[1]
int head_size = 3;             // Q.shape[2] == q_num_heads * head_size
int kv_sequence_length = 2;    // K.shape[1] and V.shape[1]
int kv_num_heads = 2;          // number of heads packed into K.shape[2] and V.shape[2]
int v_head_size = 3;           // V.shape[2] == kv_num_heads * v_head_size
int past_sequence_length = 0;  // No past state

std::vector<float> q = {
0.548814f, 0.715189f, 0.602763f,
0.423655f, 0.645894f, 0.437587f,
0.963663f, 0.383442f, 0.791725f,
0.568045f, 0.925597f, 0.071036f};

std::vector<float> k = {
0.186193f, 0.944372f, 0.739551f,
0.227415f, 0.254356f, 0.058029f,
0.311796f, 0.696343f, 0.377752f,
0.024679f, 0.067250f, 0.679393f};

std::vector<float> v = {
0.070870f, 0.292794f, 0.152355f,
0.131289f, 0.604118f, 0.382808f,
0.967795f, 0.546885f, 0.274824f,
0.896761f, 0.406733f, 0.552078f};

ASSERT_EQ(q.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * head_size));
ASSERT_EQ(k.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * head_size));
ASSERT_EQ(v.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * v_head_size));

std::vector<float> expected_y = {
0.477184f, 0.407899f, 0.207834f,
0.521223f, 0.503569f, 0.469034f,
0.485670f, 0.410303f, 0.208993f,
0.487087f, 0.512371f, 0.461486f};

// Test with no past_key or past_value (empty vectors)
// The empty vectors will cause AddOptionalInputEdge to be called, resulting in nullptr tensors
RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
q, k, v, std::vector<float>(), std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,
expected_y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
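To make the expected values reproducible, here is a minimal reference sketch of the math behind `expected_y`: standard scaled dot-product attention with scale `1/sqrt(head_size)` and no mask, computed over the test's packed 3-D layout where heads are interleaved in the last axis. `RefAttention3D` is an illustrative name, not part of the test harness:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference Y = softmax(Q*K^T * scale) * V per (batch, head), for packed 3-D
// tensors: Q is (B, S, N*H), K is (B, L, N*H), V is (B, L, N*Hv), with heads
// interleaved in the last axis; Y is returned in the same (B, S, N*Hv) layout.
std::vector<float> RefAttention3D(const std::vector<float>& q, const std::vector<float>& k,
                                  const std::vector<float>& v,
                                  int B, int N, int S, int L, int H, int Hv) {
  std::vector<float> y(static_cast<size_t>(B) * S * N * Hv, 0.0f);
  const float scale = 1.0f / std::sqrt(static_cast<float>(H));
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      for (int s = 0; s < S; ++s) {
        // Attention scores for query position s against all L key positions.
        std::vector<float> p(L);
        float mx = -1e30f;
        for (int l = 0; l < L; ++l) {
          float dot = 0.0f;
          for (int h = 0; h < H; ++h)
            dot += q[((b * S + s) * N + n) * H + h] * k[((b * L + l) * N + n) * H + h];
          p[l] = dot * scale;
          mx = std::max(mx, p[l]);
        }
        // Numerically stable softmax over l, then weighted sum of V rows.
        float sum = 0.0f;
        for (int l = 0; l < L; ++l) { p[l] = std::exp(p[l] - mx); sum += p[l]; }
        for (int l = 0; l < L; ++l)
          for (int h = 0; h < Hv; ++h)
            y[((b * S + s) * N + n) * Hv + h] += (p[l] / sum) * v[((b * L + l) * N + n) * Hv + h];
      }
    }
  }
  return y;
}
```

Calling `RefAttention3D(q, k, v, 1, 2, 2, 2, 3, 3)` with the tensors above should reproduce `expected_y` within the test tolerance; spot-checking the first row gives 0.4772, 0.4079, 0.2078.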

// Test for attention with present_key/present_value outputs but no past state
// This ensures the ConcatStateChunk logic works correctly when past_key/past_value are nullptr
TEST(AttentionTest, AttentionNoPastWithPresentOutput) {
int batch_size = 1;            // Q.shape[0]
int q_num_heads = 2;           // number of heads packed into Q.shape[2]
int q_sequence_length = 2;     // Q.shape[1]
int head_size = 2;             // Q.shape[2] == q_num_heads * head_size
int kv_sequence_length = 2;    // K.shape[1] and V.shape[1]
int kv_num_heads = 2;          // number of heads packed into K.shape[2] and V.shape[2]
int v_head_size = 2;           // V.shape[2] == kv_num_heads * v_head_size
int past_sequence_length = 0;  // No past state

std::vector<float> q = {
0.5f, 0.8f,
0.3f, 0.9f,
0.7f, 0.4f,
0.6f, 0.2f};

std::vector<float> k = {
0.1f, 0.7f,
0.4f, 0.6f,
0.8f, 0.3f,
0.2f, 0.9f};

std::vector<float> v = {
1.0f, 2.0f,
3.0f, 4.0f,
0.5f, 1.5f,
2.5f, 3.5f};

// The present key/value outputs are 4-D (B, N, L, H), so the per-head rows are transposed relative to the packed 3-D inputs.
std::vector<float> expected_present_key = {
0.1f, 0.7f,
0.8f, 0.3f,
0.4f, 0.6f,
0.2f, 0.9f};

std::vector<float> expected_present_value = {
1.0f, 2.0f,
0.5f, 1.5f,
3.0f, 4.0f,
2.5f, 3.5f};

ASSERT_EQ(q.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * head_size));
ASSERT_EQ(k.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * head_size));
ASSERT_EQ(v.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * v_head_size));

std::vector<float> expected_y = {
0.747348f, 1.747348f,
2.731472f, 3.731472f,
0.720963f, 1.720963f,
2.755302f, 3.755302f};

RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
q, k, v, std::vector<float>(), std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,
expected_y, expected_present_key, expected_present_value, std::vector<float>(),
false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
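The `expected_present_key`/`expected_present_value` data above can be derived by hand: with no past state, the present output is simply the current K (or V) transposed from the packed 3-D input layout (B, L, N*H) into the per-head 4-D layout (B, N, L, H); when a past state exists, ConcatStateChunk additionally copies each head's past chunk ahead of its current chunk. A minimal sketch of the no-past case, with illustrative names not taken from the code under test:

```cpp
#include <cstdio>
#include <vector>

// Transpose packed (B, L, N*H) keys/values into the per-head (B, N, L, H)
// "present" layout used by the test expectations (no past state, so each
// head's present chunk is just its current chunk).
std::vector<float> ToPresentLayout(const std::vector<float>& src, int B, int L, int N, int H) {
  std::vector<float> present(src.size());
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int l = 0; l < L; ++l)
        for (int h = 0; h < H; ++h)
          present[((b * N + n) * L + l) * H + h] = src[(b * L + l) * N * H + n * H + h];
  return present;
}

int main() {
  // K input from AttentionNoPastWithPresentOutput: B=1, L=2, N=2, H=2.
  const std::vector<float> k = {0.1f, 0.7f, 0.4f, 0.6f, 0.8f, 0.3f, 0.2f, 0.9f};
  for (float x : ToPresentLayout(k, 1, 2, 2, 2)) std::printf("%.1f ", x);
  std::printf("\n");  // prints: 0.1 0.7 0.8 0.3 0.4 0.6 0.2 0.9 == expected_present_key
  return 0;
}
```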

} // namespace test
} // namespace onnxruntime