@@ -27,82 +27,127 @@ def get_row(doc_id: str, score: float, content: str = "content") -> RowMapping:
2727
class TestWeightedSumRanking:
    """Tests for ``weighted_sum_ranking``.

    The expected values below follow min-max normalization of each result
    list, as spelled out in the per-test calculation comments:

    * "Lower is better" primary distances (default cosine, Euclidean) are
      inverted: ``1.0 - (score - min) / range``.
    * "Higher is better" scores (keyword search, inner product) are
      normalized as ``(score - min) / range``.

    Every test relies on the default weights, which the calculations take
    to be 0.5 for both the primary and the secondary results.
    """

    def test_empty_inputs(self) -> None:
        """Tests that the function handles empty inputs gracefully."""
        results = weighted_sum_ranking([], [])
        assert results == []

    def test_primary_only_cosine_default(self) -> None:
        """Tests ranking with only primary results using default cosine distance."""
        primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
        # --- Calculation (Cosine = lower is better) ---
        # Scores: [0.8, 0.6]. Range: 0.2. Min: 0.6.
        # p1 norm: 1.0 - ((0.8 - 0.6) / 0.2) = 0.0
        # p2 norm: 1.0 - ((0.6 - 0.6) / 0.2) = 1.0
        # Weighted (0.5): p1 = 0.0, p2 = 0.5
        # Order: p2, p1
        results = weighted_sum_ranking(
            primary,  # type: ignore
            [],
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "p2"
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(0.0)

    def test_secondary_only(self) -> None:
        """Tests ranking with only secondary (keyword) results."""
        secondary = [get_row("s1", 15.0), get_row("s2", 5.0)]
        # --- Calculation (Keyword = higher is better) ---
        # Scores: [15.0, 5.0]. Range: 10.0. Min: 5.0.
        # s1 norm: (15.0 - 5.0) / 10.0 = 1.0
        # s2 norm: (5.0 - 5.0) / 10.0 = 0.0
        # Weighted (0.5): s1 = 0.5, s2 = 0.0
        # Order: s1, s2
        results = weighted_sum_ranking(
            [],
            secondary,  # type: ignore
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(0.0)

    def test_mixed_results_cosine(self) -> None:
        """Tests combining cosine (lower is better) and keyword (higher is better) scores."""
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 9.0), get_row("s_only", 6.0)]
        # --- Calculation ---
        # Primary norm (inverted): common=0.0, p_only=1.0
        # Secondary norm: common=1.0, s_only=0.0
        # Weighted (0.5):
        # common = (0.0 * 0.5) + (1.0 * 0.5) = 0.5
        # p_only = (1.0 * 0.5) + 0 = 0.5
        # s_only = 0 + (0.0 * 0.5) = 0.0
        results = weighted_sum_ranking(
            primary,  # type: ignore
            secondary,  # type: ignore
        )
        assert len(results) == 3
        # "common" and "p_only" tie at 0.5, so their relative order is not
        # pinned; only the membership of the top two and their scores are.
        top_ids = {res["id_val"] for res in results[:2]}
        assert top_ids == {"common", "p_only"}
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["distance"] == pytest.approx(0.5)
        assert results[2]["id_val"] == "s_only"
        assert results[2]["distance"] == pytest.approx(0.0)

    def test_primary_max_inner_product(self) -> None:
        """Tests using INNER_PRODUCT (higher is better) for primary search."""
        primary = [get_row("best", 0.9), get_row("worst", 0.1)]
        secondary = [get_row("best", 20.0), get_row("worst", 5.0)]
        # --- Calculation ---
        # Primary norm (NOT inverted): best=1.0, worst=0.0
        # Secondary norm: best=1.0, worst=0.0
        # Weighted (0.5):
        # best = (1.0 * 0.5) + (1.0 * 0.5) = 1.0
        # worst = (0.0 * 0.5) + (0.0 * 0.5) = 0.0
        results = weighted_sum_ranking(
            primary,  # type: ignore
            secondary,  # type: ignore
            distance_strategy=DistanceStrategy.INNER_PRODUCT,
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "best"
        assert results[0]["distance"] == pytest.approx(1.0)
        assert results[1]["id_val"] == "worst"
        assert results[1]["distance"] == pytest.approx(0.0)

    def test_primary_euclidean(self) -> None:
        """Tests using EUCLIDEAN (lower is better) for primary search."""
        primary = [get_row("closer", 10.5), get_row("farther", 25.5)]
        secondary = [get_row("closer", 100.0), get_row("farther", 10.0)]
        # --- Calculation ---
        # Primary norm (inverted): closer=1.0, farther=0.0
        # Secondary norm: closer=1.0, farther=0.0
        # Weighted (0.5):
        # closer = (1.0 * 0.5) + (1.0 * 0.5) = 1.0
        # farther = (0.0 * 0.5) + (0.0 * 0.5) = 0.0
        results = weighted_sum_ranking(
            primary,  # type: ignore
            secondary,  # type: ignore
            distance_strategy=DistanceStrategy.EUCLIDEAN,
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "closer"
        assert results[0]["distance"] == pytest.approx(1.0)
        assert results[1]["id_val"] == "farther"
        assert results[1]["distance"] == pytest.approx(0.0)

    def test_fetch_top_k(self) -> None:
        """Tests that fetch_top_k correctly limits the number of results."""
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        # p0=1.0, p1=0.9, p2=0.8, p3=0.7, p4=0.6
        # The best scores (lowest distance) are p4 and p3
        results = weighted_sum_ranking(
            primary,  # type: ignore
            [],
            fetch_top_k=2,
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "p4"  # Has the best normalized score
        assert results[1]["id_val"] == "p3"
106151
107152
108153class TestReciprocalRankFusion :
0 commit comments