@@ -1576,3 +1576,55 @@ def test_get_prompt_request_pieces_sorts(
1576
1576
if new_value != current_value :
1577
1577
if any (o .conversation_id == current_value for o in response [response .index (obj ) :]):
1578
1578
assert False , "Conversation IDs are not grouped together"
1579
+
1580
+
1581
+ def test_get_prompt_request_pieces_calls_populate_prompt_piece_scores (
1582
+ duckdb_instance : MemoryInterface , sample_conversations : list [PromptRequestPiece ]
1583
+ ):
1584
+ conversation_id = sample_conversations [0 ].conversation_id
1585
+ duckdb_instance .add_request_pieces_to_memory (request_pieces = sample_conversations )
1586
+
1587
+ with patch .object (duckdb_instance , "populate_prompt_piece_scores" ) as mock_populate :
1588
+ duckdb_instance .get_prompt_request_pieces (conversation_id = conversation_id )
1589
+ assert mock_populate .called
1590
+
1591
+
1592
+ def test_populate_prompt_piece_scores_duplicate_piece (duckdb_instance : MemoryInterface ):
1593
+ original_id = uuid4 ()
1594
+ duplicate_id = uuid4 ()
1595
+
1596
+ pieces = [
1597
+ PromptRequestPiece (
1598
+ id = original_id ,
1599
+ role = "assistant" ,
1600
+ original_value = "prompt text" ,
1601
+ ),
1602
+ PromptRequestPiece (
1603
+ id = duplicate_id ,
1604
+ role = "assistant" ,
1605
+ original_value = "prompt text" ,
1606
+ original_prompt_id = original_id ,
1607
+ ),
1608
+ ]
1609
+
1610
+ duckdb_instance .add_request_pieces_to_memory (request_pieces = pieces )
1611
+
1612
+ score = Score (
1613
+ score_value = str (0.8 ),
1614
+ score_value_description = "Sample description" ,
1615
+ score_type = "float_scale" ,
1616
+ score_category = "Sample category" ,
1617
+ score_rationale = "Sample rationale" ,
1618
+ score_metadata = "Sample metadata" ,
1619
+ prompt_request_response_id = original_id ,
1620
+ )
1621
+ duckdb_instance .add_scores_to_memory (scores = [score ])
1622
+
1623
+ duckdb_instance .populate_prompt_piece_scores (pieces )
1624
+
1625
+ assert len (pieces [0 ].scores ) == 1
1626
+ assert pieces [0 ].scores [0 ].score_value == "0.8"
1627
+
1628
+ # Check that the duplicate piece has the same score as the original
1629
+ assert len (pieces [1 ].scores ) == 1
1630
+ assert pieces [1 ].scores [0 ].score_value == "0.8"
0 commit comments