1+ from typing import cast
2+
13import pytest
4+ from sqlalchemy import RowMapping
25
36from langchain_postgres .v2 .hybrid_search_config import (
47 reciprocal_rank_fusion ,
58 weighted_sum_ranking ,
69)
10+ from langchain_postgres .v2 .indexes import DistanceStrategy
711
812
913# Helper to create mock input items that mimic RowMapping for the fusion functions
10- def get_row (doc_id : str , score : float , content : str = "content" ) -> dict :
14+ def get_row (doc_id : str , score : float , content : str = "content" ) -> RowMapping :
1115 """
1216 Simulates a RowMapping-like dictionary.
1317 The fusion functions expect to extract doc_id as the first value and
@@ -17,7 +21,8 @@ def get_row(doc_id: str, score: float, content: str = "content") -> dict:
1721 # Python dicts maintain insertion order (Python 3.7+).
1822 # This structure ensures list(row.values())[0] is doc_id and
1923 # list(row.values())[-1] is score.
20- return {"id_val" : doc_id , "content_field" : content , "distance" : score }
24+ row_dict = {"id_val" : doc_id , "content_field" : content , "distance" : score }
25+ return cast (RowMapping , row_dict )
2126
2227
2328class TestWeightedSumRanking :
@@ -102,30 +107,31 @@ def test_fetch_top_k(self) -> None:
102107
103108class TestReciprocalRankFusion :
104109 def test_empty_inputs (self ) -> None :
110+ """Tests that the function handles empty inputs gracefully."""
105111 results = reciprocal_rank_fusion ([], [])
106112 assert results == []
107113
108114 def test_primary_only (self ) -> None :
109- primary = [
110- get_row ("p1" , 0.8 ),
111- get_row ("p2" , 0.6 ),
112- ] # p1 rank 0, p2 rank 1
115+ """Tests RRF with only primary results using default cosine (lower is better)."""
116+ primary = [get_row ("p1" , 0.8 ), get_row ("p2" , 0.6 )]
113117 rrf_k = 60
114- # p1_score = 1 / (0 + 60)
115- # p2_score = 1 / (1 + 60)
118+ # --- Calculation (Cosine: lower is better) ---
119+ # Sorted order: p2 (0.6) -> rank 0; p1 (0.8) -> rank 1
120+ # p2_score = 1 / (0 + 60)
121+ # p1_score = 1 / (1 + 60)
116122 results = reciprocal_rank_fusion (primary , [], rrf_k = rrf_k ) # type: ignore
117123 assert len (results ) == 2
118- assert results [0 ]["id_val" ] == "p1 "
124+ assert results [0 ]["id_val" ] == "p2 "
119125 assert results [0 ]["distance" ] == pytest .approx (1.0 / (0 + rrf_k ))
120- assert results [1 ]["id_val" ] == "p2 "
126+ assert results [1 ]["id_val" ] == "p1 "
121127 assert results [1 ]["distance" ] == pytest .approx (1.0 / (1 + rrf_k ))
122128
123129 def test_secondary_only (self ) -> None :
124- secondary = [
125- get_row ("s1" , 0.9 ),
126- get_row ("s2" , 0.7 ),
127- ] # s1 rank 0, s2 rank 1
130+ """Tests RRF with only secondary results (higher is better)."""
131+ secondary = [get_row ("s1" , 0.9 ), get_row ("s2" , 0.7 )]
128132 rrf_k = 60
133+ # --- Calculation (Keyword: higher is better) ---
134+ # Sorted order: s1 (0.9) -> rank 0; s2 (0.7) -> rank 1
129135 results = reciprocal_rank_fusion ([], secondary , rrf_k = rrf_k ) # type: ignore
130136 assert len (results ) == 2
131137 assert results [0 ]["id_val" ] == "s1"
@@ -134,96 +140,130 @@ def test_secondary_only(self) -> None:
134140 assert results [1 ]["distance" ] == pytest .approx (1.0 / (1 + rrf_k ))
135141
136142 def test_mixed_results_default_k (self ) -> None :
137- primary = [get_row ("common" , 0.8 ), get_row ("p_only" , 0.7 )]
138- secondary = [get_row ("common" , 0.9 ), get_row ("s_only" , 0.6 )]
143+ """Tests fusion with default cosine (lower better) and keyword (higher better)."""
144+ primary = [
145+ get_row ("common" , 0.8 ),
146+ get_row ("p_only" , 0.7 ),
147+ ] # Order: p_only, common
148+ secondary = [
149+ get_row ("common" , 0.9 ),
150+ get_row ("s_only" , 0.6 ),
151+ ] # Order: common, s_only
139152 rrf_k = 60
140- # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
141- # p_only_score = (1/(1+k))_prim = 1/(k+1)
142- # s_only_score = (1/(1+k))_sec = 1/(k+1)
153+ # --- Calculation ---
154+ # common: rank 1 in P (1/61) + rank 0 in S (1/60) -> highest score
155+ # p_only: rank 0 in P (1/60)
156+ # s_only: rank 1 in S (1/61)
143157 results = reciprocal_rank_fusion (primary , secondary , rrf_k = rrf_k ) # type: ignore
144158 assert len (results ) == 3
145159 assert results [0 ]["id_val" ] == "common"
146- assert results [0 ]["distance" ] == pytest .approx (2.0 / rrf_k )
147- # Check the next two elements, their order might vary due to tie in score
148- next_ids = {results [1 ]["id_val" ], results [2 ]["id_val" ]}
149- next_scores = {results [1 ]["distance" ], results [2 ]["distance" ]}
150- assert next_ids == {"p_only" , "s_only" }
151- for score in next_scores :
152- assert score == pytest .approx (1.0 / (1 + rrf_k ))
160+ assert results [0 ]["distance" ] == pytest .approx (1 / 61 + 1 / 60 )
161+ assert results [1 ]["id_val" ] == "p_only"
162+ assert results [1 ]["distance" ] == pytest .approx (1 / 60 )
163+ assert results [2 ]["id_val" ] == "s_only"
164+ assert results [2 ]["distance" ] == pytest .approx (1 / 61 )
153165
154166 def test_fetch_top_k_rrf (self ) -> None :
167+ """Tests that fetch_top_k limits results correctly after fusion."""
168+ # Using cosine distance (lower is better)
155169 primary = [get_row (f"p{ i } " , (10 - i ) / 10.0 ) for i in range (5 )]
156- rrf_k = 1
157- results = reciprocal_rank_fusion (primary , [], rrf_k = rrf_k , fetch_top_k = 2 ) # type: ignore
170+ # Scores: [1.0, 0.9, 0.8, 0.7, 0.6]
171+ # Sorted order: p4 (0.6), p3 (0.7), p2 (0.8), ...
172+ results = reciprocal_rank_fusion (primary , [], fetch_top_k = 2 ) # type: ignore
158173 assert len (results ) == 2
159- assert results [0 ]["id_val" ] == "p0"
160- assert results [0 ]["distance" ] == pytest .approx (1.0 / (0 + rrf_k ))
161- assert results [1 ]["id_val" ] == "p1"
162- assert results [1 ]["distance" ] == pytest .approx (1.0 / (1 + rrf_k ))
174+ assert results [0 ]["id_val" ] == "p4"
175+ assert results [1 ]["id_val" ] == "p3"
163176
164177 def test_rrf_content_preservation (self ) -> None :
178+ """Tests that the row data from a document's first occurrence is kept."""
165179 primary = [get_row ("doc1" , 0.9 , content = "Primary Content" )]
166180 secondary = [get_row ("doc1" , 0.8 , content = "Secondary Content" )]
167- # RRF processes primary then secondary. If a doc is in both,
168- # the content from the secondary list will overwrite primary's .
169- results = reciprocal_rank_fusion (primary , secondary , rrf_k = 60 ) # type: ignore
181+ # RRF processes primary first. When "doc1" is seen, its data is stored.
182+ # It will not be overwritten by the "doc1" from the secondary list.
183+ results = reciprocal_rank_fusion (primary , secondary ) # type: ignore
170184 assert len (results ) == 1
171185 assert results [0 ]["id_val" ] == "doc1"
172- assert results [0 ]["content_field" ] == "Secondary Content"
186+ assert results [0 ]["content_field" ] == "Primary Content"
173187
174- # If only in primary
175- results_prim_only = reciprocal_rank_fusion (primary , [] , rrf_k = 60 ) # type: ignore
176- assert results_prim_only [0 ]["content_field" ] == "Primary Content"
188+ # If only in secondary
189+ results_prim_only = reciprocal_rank_fusion ([], secondary , rrf_k = 60 ) # type: ignore
190+ assert results_prim_only [0 ]["content_field" ] == "Secondary Content"
177191
178192 def test_reordering_from_inputs_rrf (self ) -> None :
179- """
180- Tests that RRF fused ranking can be different from both primary and secondary
181- input rankings.
182- Primary Order: A, B, C
183- Secondary Order: C, B, A
184- Fused Order: (A, C) tied, then B
185- """
186- primary = [
187- get_row ("docA" , 0.9 ),
188- get_row ("docB" , 0.8 ),
189- get_row ("docC" , 0.1 ),
190- ]
191- secondary = [
192- get_row ("docC" , 0.9 ),
193- get_row ("docB" , 0.5 ),
194- get_row ("docA" , 0.2 ),
195- ]
196- rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation
197- # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
198- # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
199- # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
193+ """Tests that the fused RRF order (C, B, A) can differ from the primary input list order (A, B, C)."""
194+ primary = [get_row ("docA" , 0.9 ), get_row ("docB" , 0.8 ), get_row ("docC" , 0.1 )]
195+ secondary = [get_row ("docC" , 0.9 ), get_row ("docB" , 0.5 ), get_row ("docA" , 0.2 )]
196+ rrf_k = 1.0
197+ # --- Calculation (Primary sorted ascending, Secondary descending) ---
198+ # Primary ranks: docC (0), docB (1), docA (2)
199+ # Secondary ranks: docC (0), docB (1), docA (2)
200+ # docC_score = 1/(0+1) [P] + 1/(0+1) [S] = 2.0
201+ # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1.0
202+ # docA_score = 1/(2+1) [P] + 1/(2+1) [S] = 2/3
200203 results = reciprocal_rank_fusion (primary , secondary , rrf_k = rrf_k ) # type: ignore
201204 assert len (results ) == 3
202- assert {results [0 ]["id_val" ], results [1 ]["id_val" ]} == {"docA" , "docC" }
203- assert results [0 ]["distance" ] == pytest .approx (4.0 / 3.0 )
204- assert results [1 ]["distance" ] == pytest .approx (4.0 / 3.0 )
205- assert results [2 ]["id_val" ] == "docB"
206- assert results [2 ]["distance" ] == pytest .approx (1.0 )
207-
208- def test_reordering_from_inputs_weighted_sum (self ) -> None :
209- """
210- Tests that the fused ranking can be different from both primary and secondary
211- input rankings.
212- Primary Order: A (0.9), B (0.7)
213- Secondary Order: B (0.8), A (0.2)
214- Fusion (0.5/0.5 weights):
215- docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
216- docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
217- Expected Fused Order: docB (0.75), docA (0.55)
218- This is different from Primary (A,B) and Secondary (B,A) in terms of
219- original score, but the fusion logic changes the effective contribution).
220- """
221- primary = [get_row ("docA" , 0.9 ), get_row ("docB" , 0.7 )]
222- secondary = [get_row ("docB" , 0.8 ), get_row ("docA" , 0.2 )]
205+ assert results [0 ]["id_val" ] == "docC"
206+ assert results [0 ]["distance" ] == pytest .approx (2.0 )
207+ assert results [1 ]["id_val" ] == "docB"
208+ assert results [1 ]["distance" ] == pytest .approx (1.0 )
209+ assert results [2 ]["id_val" ] == "docA"
210+ assert results [2 ]["distance" ] == pytest .approx (2.0 / 3.0 )
223211
224- results = weighted_sum_ranking (primary , secondary ) # type: ignore
212+ # --------------------------------------------------------------------------
213+ ## New Tests for Other Strategies and Edge Cases
214+
215+ def test_mixed_results_max_inner_product (self ) -> None :
216+ """Tests fusion with INNER_PRODUCT (higher is better) for primary."""
217+ primary = [get_row ("best" , 0.9 ), get_row ("worst" , 0.1 )] # Order: best, worst
218+ secondary = [get_row ("best" , 20.0 ), get_row ("worst" , 5.0 )] # Order: best, worst
219+ rrf_k = 10
220+ # best: rank 0 in P + rank 0 in S -> 1/10 + 1/10 = 0.2
221+ # worst: rank 1 in P + rank 1 in S -> 1/11 + 1/11
222+ results = reciprocal_rank_fusion (
223+ primary , # type: ignore
224+ secondary , # type: ignore
225+ rrf_k = rrf_k ,
226+ distance_strategy = DistanceStrategy .INNER_PRODUCT ,
227+ )
228+ assert len (results ) == 2
229+ assert results [0 ]["id_val" ] == "best"
230+ assert results [0 ]["distance" ] == pytest .approx (0.2 )
231+ assert results [1 ]["id_val" ] == "worst"
232+ assert results [1 ]["distance" ] == pytest .approx (2.0 / 11.0 )
233+
234+ def test_mixed_results_euclidean (self ) -> None :
235+ """Tests fusion with EUCLIDEAN (lower is better) for primary."""
236+ primary = [
237+ get_row ("closer" , 10.5 ),
238+ get_row ("farther" , 25.5 ),
239+ ] # Order: closer, farther
240+ secondary = [
241+ get_row ("closer" , 100.0 ),
242+ get_row ("farther" , 10.0 ),
243+ ] # Order: closer, farther
244+ rrf_k = 10
245+ # closer: rank 0 in P + rank 0 in S -> 1/10 + 1/10 = 0.2
246+ # farther: rank 1 in P + rank 1 in S -> 1/11 + 1/11
247+ results = reciprocal_rank_fusion (
248+ primary , # type: ignore
249+ secondary , # type: ignore
250+ rrf_k = rrf_k ,
251+ distance_strategy = DistanceStrategy .EUCLIDEAN ,
252+ )
225253 assert len (results ) == 2
226- assert results [0 ]["id_val" ] == "docB"
227- assert results [0 ]["distance" ] == pytest .approx (0.75 )
228- assert results [1 ]["id_val" ] == "docA"
229- assert results [1 ]["distance" ] == pytest .approx (0.55 )
254+ assert results [0 ]["id_val" ] == "closer"
255+ assert results [0 ]["distance" ] == pytest .approx (0.2 )
256+ assert results [1 ]["id_val" ] == "farther"
257+ assert results [1 ]["distance" ] == pytest .approx (2.0 / 11.0 )
258+
259+ def test_rrf_with_identical_scores (self ) -> None :
260+ """Tests that stable sort is preserved for identical scores."""
261+ # Python's sorted() is stable. p1 appears before p2 in the list.
262+ primary = [get_row ("p1" , 0.5 ), get_row ("p2" , 0.5 )]
263+ rrf_k = 60
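264+ # Expected order (stable sort): p1 (rank 0), p2 (rank 1). rrf_k is not passed below, so the 1/60 and 1/61 assertions assume the function's default rrf_k is 60.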
265+ results = reciprocal_rank_fusion (primary , []) # type: ignore
266+ assert results [0 ]["id_val" ] == "p1"
267+ assert results [0 ]["distance" ] == pytest .approx (1 / 60 )
268+ assert results [1 ]["id_val" ] == "p2"
269+ assert results [1 ]["distance" ] == pytest .approx (1 / 61 )
0 commit comments