Skip to content

Commit 2ac41ca

Browse files
fix: Normalize scores in RRF ranking in hybrid search (#256)
Fixes Major Issue 1 in #234 --------- Co-authored-by: Averi Kitsch <[email protected]>
1 parent 70ba3f2 commit 2ac41ca

File tree

3 files changed

+165
-110
lines changed

3 files changed

+165
-110
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@ async def __query_collection(
670670
dense_results,
671671
sparse_results,
672672
**hybrid_search_config.fusion_function_parameters,
673+
distance_strategy=self.distance_strategy,
673674
)
674675
return combined_results
675676
return dense_results

langchain_postgres/v2/hybrid_search_config.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44

55
from sqlalchemy import RowMapping
66

7+
from .indexes import DistanceStrategy
8+
79

810
def weighted_sum_ranking(
911
primary_search_results: Sequence[RowMapping],
1012
secondary_search_results: Sequence[RowMapping],
1113
primary_results_weight: float = 0.5,
1214
secondary_results_weight: float = 0.5,
1315
fetch_top_k: int = 4,
16+
**kwargs: Any,
1417
) -> Sequence[dict[str, Any]]:
1518
"""
1619
Ranks documents using a weighted sum of scores from two sources.
@@ -69,6 +72,7 @@ def reciprocal_rank_fusion(
6972
secondary_search_results: Sequence[RowMapping],
7073
rrf_k: float = 60,
7174
fetch_top_k: int = 4,
75+
**kwargs: Any,
7276
) -> Sequence[dict[str, Any]]:
7377
"""
7478
Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two sources.
@@ -87,35 +91,45 @@ def reciprocal_rank_fusion(
8791
A list of row dictionaries whose "distance" field holds the fused RRF score, sorted by that score
8892
in descending order.
8993
"""
94+
distance_strategy = kwargs.get(
95+
"distance_strategy", DistanceStrategy.COSINE_DISTANCE
96+
)
9097
rrf_scores: dict[str, dict[str, Any]] = {}
9198

9299
# Process results from primary source
93-
for rank, row in enumerate(
94-
sorted(primary_search_results, key=lambda item: item["distance"], reverse=True)
95-
):
96-
values = list(row.values())
97-
doc_id = str(values[0])
98-
row_values = dict(row)
99-
primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
100-
primary_score += 1.0 / (rank + rrf_k)
101-
row_values["distance"] = primary_score
102-
rrf_scores[doc_id] = row_values
100+
# Determine sorting order based on the vector distance strategy.
101+
# For COSINE_DISTANCE and EUCLIDEAN (distance metrics, lower is better), we sort ascending (reverse=False).
102+
# For INNER_PRODUCT (similarity), we sort descending (reverse=True).
103+
is_similarity_metric = distance_strategy == DistanceStrategy.INNER_PRODUCT
104+
sorted_primary = sorted(
105+
primary_search_results,
106+
key=lambda item: item["distance"],
107+
reverse=is_similarity_metric,
108+
)
109+
110+
for rank, row in enumerate(sorted_primary):
111+
doc_id = str(list(row.values())[0])
112+
if doc_id not in rrf_scores:
113+
rrf_scores[doc_id] = dict(row)
114+
rrf_scores[doc_id]["distance"] = 0.0
115+
# Add the "normalized" rank score
116+
rrf_scores[doc_id]["distance"] += 1.0 / (rank + rrf_k)
103117

104118
# Process results from secondary source
105-
for rank, row in enumerate(
106-
sorted(
107-
secondary_search_results, key=lambda item: item["distance"], reverse=True
108-
)
109-
):
110-
values = list(row.values())
111-
doc_id = str(values[0])
112-
row_values = dict(row)
113-
secondary_score = (
114-
rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
115-
)
116-
secondary_score += 1.0 / (rank + rrf_k)
117-
row_values["distance"] = secondary_score
118-
rrf_scores[doc_id] = row_values
119+
# Keyword search relevance is always "higher is better" -> sort descending
120+
sorted_secondary = sorted(
121+
secondary_search_results,
122+
key=lambda item: item["distance"],
123+
reverse=True,
124+
)
125+
126+
for rank, row in enumerate(sorted_secondary):
127+
doc_id = str(list(row.values())[0])
128+
if doc_id not in rrf_scores:
129+
rrf_scores[doc_id] = dict(row)
130+
rrf_scores[doc_id]["distance"] = 0.0
131+
# Add the rank score from this list to the existing score
132+
rrf_scores[doc_id]["distance"] += 1.0 / (rank + rrf_k)
119133

120134
# Sort the results by rrf score in descending order
121135
# Sort the results by weighted score in descending order
Lines changed: 126 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from typing import cast
2+
13
import pytest
4+
from sqlalchemy import RowMapping
25

36
from langchain_postgres.v2.hybrid_search_config import (
47
reciprocal_rank_fusion,
58
weighted_sum_ranking,
69
)
10+
from langchain_postgres.v2.indexes import DistanceStrategy
711

812

913
# Helper to create mock input items that mimic RowMapping for the fusion functions
10-
def get_row(doc_id: str, score: float, content: str = "content") -> dict:
14+
def get_row(doc_id: str, score: float, content: str = "content") -> RowMapping:
1115
"""
1216
Simulates a RowMapping-like dictionary.
1317
The fusion functions expect to extract doc_id as the first value and
@@ -17,7 +21,8 @@ def get_row(doc_id: str, score: float, content: str = "content") -> dict:
1721
# Python dicts maintain insertion order (Python 3.7+).
1822
# This structure ensures list(row.values())[0] is doc_id and
1923
# list(row.values())[-1] is score.
20-
return {"id_val": doc_id, "content_field": content, "distance": score}
24+
row_dict = {"id_val": doc_id, "content_field": content, "distance": score}
25+
return cast(RowMapping, row_dict)
2126

2227

2328
class TestWeightedSumRanking:
@@ -102,30 +107,31 @@ def test_fetch_top_k(self) -> None:
102107

103108
class TestReciprocalRankFusion:
104109
def test_empty_inputs(self) -> None:
110+
"""Tests that the function handles empty inputs gracefully."""
105111
results = reciprocal_rank_fusion([], [])
106112
assert results == []
107113

108114
def test_primary_only(self) -> None:
109-
primary = [
110-
get_row("p1", 0.8),
111-
get_row("p2", 0.6),
112-
] # p1 rank 0, p2 rank 1
115+
"""Tests RRF with only primary results using default cosine (lower is better)."""
116+
primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
113117
rrf_k = 60
114-
# p1_score = 1 / (0 + 60)
115-
# p2_score = 1 / (1 + 60)
118+
# --- Calculation (Cosine: lower is better) ---
119+
# Sorted order: p2 (0.6) -> rank 0; p1 (0.8) -> rank 1
120+
# p2_score = 1 / (0 + 60)
121+
# p1_score = 1 / (1 + 60)
116122
results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore
117123
assert len(results) == 2
118-
assert results[0]["id_val"] == "p1"
124+
assert results[0]["id_val"] == "p2"
119125
assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
120-
assert results[1]["id_val"] == "p2"
126+
assert results[1]["id_val"] == "p1"
121127
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
122128

123129
def test_secondary_only(self) -> None:
124-
secondary = [
125-
get_row("s1", 0.9),
126-
get_row("s2", 0.7),
127-
] # s1 rank 0, s2 rank 1
130+
"""Tests RRF with only secondary results (higher is better)."""
131+
secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
128132
rrf_k = 60
133+
# --- Calculation (Keyword: higher is better) ---
134+
# Sorted order: s1 (0.9) -> rank 0; s2 (0.7) -> rank 1
129135
results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore
130136
assert len(results) == 2
131137
assert results[0]["id_val"] == "s1"
@@ -134,96 +140,130 @@ def test_secondary_only(self) -> None:
134140
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
135141

136142
def test_mixed_results_default_k(self) -> None:
137-
primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
138-
secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
143+
"""Tests fusion with default cosine (lower better) and keyword (higher better)."""
144+
primary = [
145+
get_row("common", 0.8),
146+
get_row("p_only", 0.7),
147+
] # Order: p_only, common
148+
secondary = [
149+
get_row("common", 0.9),
150+
get_row("s_only", 0.6),
151+
] # Order: common, s_only
139152
rrf_k = 60
140-
# common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
141-
# p_only_score = (1/(1+k))_prim = 1/(k+1)
142-
# s_only_score = (1/(1+k))_sec = 1/(k+1)
153+
# --- Calculation ---
154+
# common: rank 1 in P (1/61) + rank 0 in S (1/60) -> highest score
155+
# p_only: rank 0 in P (1/60)
156+
# s_only: rank 1 in S (1/61)
143157
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
144158
assert len(results) == 3
145159
assert results[0]["id_val"] == "common"
146-
assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
147-
# Check the next two elements, their order might vary due to tie in score
148-
next_ids = {results[1]["id_val"], results[2]["id_val"]}
149-
next_scores = {results[1]["distance"], results[2]["distance"]}
150-
assert next_ids == {"p_only", "s_only"}
151-
for score in next_scores:
152-
assert score == pytest.approx(1.0 / (1 + rrf_k))
160+
assert results[0]["distance"] == pytest.approx(1 / 61 + 1 / 60)
161+
assert results[1]["id_val"] == "p_only"
162+
assert results[1]["distance"] == pytest.approx(1 / 60)
163+
assert results[2]["id_val"] == "s_only"
164+
assert results[2]["distance"] == pytest.approx(1 / 61)
153165

154166
def test_fetch_top_k_rrf(self) -> None:
167+
"""Tests that fetch_top_k limits results correctly after fusion."""
168+
# Using cosine distance (lower is better)
155169
primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
156-
rrf_k = 1
157-
results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore
170+
# Scores: [1.0, 0.9, 0.8, 0.7, 0.6]
171+
# Sorted order: p4 (0.6), p3 (0.7), p2 (0.8), ...
172+
results = reciprocal_rank_fusion(primary, [], fetch_top_k=2) # type: ignore
158173
assert len(results) == 2
159-
assert results[0]["id_val"] == "p0"
160-
assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
161-
assert results[1]["id_val"] == "p1"
162-
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
174+
assert results[0]["id_val"] == "p4"
175+
assert results[1]["id_val"] == "p3"
163176

164177
def test_rrf_content_preservation(self) -> None:
178+
"""Tests that the data from the first time a document is seen is kept."""
165179
primary = [get_row("doc1", 0.9, content="Primary Content")]
166180
secondary = [get_row("doc1", 0.8, content="Secondary Content")]
167-
# RRF processes primary then secondary. If a doc is in both,
168-
# the content from the secondary list will overwrite primary's.
169-
results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore
181+
# RRF processes primary first. When "doc1" is seen, its data is stored.
182+
# It will not be overwritten by the "doc1" from the secondary list.
183+
results = reciprocal_rank_fusion(primary, secondary) # type: ignore
170184
assert len(results) == 1
171185
assert results[0]["id_val"] == "doc1"
172-
assert results[0]["content_field"] == "Secondary Content"
186+
assert results[0]["content_field"] == "Primary Content"
173187

174-
# If only in primary
175-
results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore
176-
assert results_prim_only[0]["content_field"] == "Primary Content"
188+
# If only in secondary
189+
results_prim_only = reciprocal_rank_fusion([], secondary, rrf_k=60) # type: ignore
190+
assert results_prim_only[0]["content_field"] == "Secondary Content"
177191

178192
def test_reordering_from_inputs_rrf(self) -> None:
179-
"""
180-
Tests that RRF fused ranking can be different from both primary and secondary
181-
input rankings.
182-
Primary Order: A, B, C
183-
Secondary Order: C, B, A
184-
Fused Order: (A, C) tied, then B
185-
"""
186-
primary = [
187-
get_row("docA", 0.9),
188-
get_row("docB", 0.8),
189-
get_row("docC", 0.1),
190-
]
191-
secondary = [
192-
get_row("docC", 0.9),
193-
get_row("docB", 0.5),
194-
get_row("docA", 0.2),
195-
]
196-
rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation
197-
# docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
198-
# docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
199-
# docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
193+
"""Tests that RRF can produce a ranking different from the inputs."""
194+
primary = [get_row("docA", 0.9), get_row("docB", 0.8), get_row("docC", 0.1)]
195+
secondary = [get_row("docC", 0.9), get_row("docB", 0.5), get_row("docA", 0.2)]
196+
rrf_k = 1.0
197+
# --- Calculation (Primary sorted ascending, Secondary descending) ---
198+
# Primary ranks: docC (0), docB (1), docA (2)
199+
# Secondary ranks: docC (0), docB (1), docA (2)
200+
# docC_score = 1/(0+1) [P] + 1/(0+1) [S] = 2.0
201+
# docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1.0
202+
# docA_score = 1/(2+1) [P] + 1/(2+1) [S] = 2/3
200203
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
201204
assert len(results) == 3
202-
assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
203-
assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
204-
assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
205-
assert results[2]["id_val"] == "docB"
206-
assert results[2]["distance"] == pytest.approx(1.0)
207-
208-
def test_reordering_from_inputs_weighted_sum(self) -> None:
209-
"""
210-
Tests that the fused ranking can be different from both primary and secondary
211-
input rankings.
212-
Primary Order: A (0.9), B (0.7)
213-
Secondary Order: B (0.8), A (0.2)
214-
Fusion (0.5/0.5 weights):
215-
docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
216-
docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
217-
Expected Fused Order: docB (0.75), docA (0.55)
218-
This is different from Primary (A,B) and Secondary (B,A) in terms of
219-
original score, but the fusion logic changes the effective contribution).
220-
"""
221-
primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
222-
secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]
205+
assert results[0]["id_val"] == "docC"
206+
assert results[0]["distance"] == pytest.approx(2.0)
207+
assert results[1]["id_val"] == "docB"
208+
assert results[1]["distance"] == pytest.approx(1.0)
209+
assert results[2]["id_val"] == "docA"
210+
assert results[2]["distance"] == pytest.approx(2.0 / 3.0)
223211

224-
results = weighted_sum_ranking(primary, secondary) # type: ignore
212+
# --------------------------------------------------------------------------
213+
## New Tests for Other Strategies and Edge Cases
214+
215+
def test_mixed_results_max_inner_product(self) -> None:
216+
"""Tests fusion with INNER_PRODUCT (similarity: higher is better) for primary."""
217+
primary = [get_row("best", 0.9), get_row("worst", 0.1)] # Order: best, worst
218+
secondary = [get_row("best", 20.0), get_row("worst", 5.0)] # Order: best, worst
219+
rrf_k = 10
220+
# best: rank 0 in P + rank 0 in S -> 1/10 + 1/10 = 0.2
221+
# worst: rank 1 in P + rank 1 in S -> 1/11 + 1/11
222+
results = reciprocal_rank_fusion(
223+
primary, # type: ignore
224+
secondary, # type: ignore
225+
rrf_k=rrf_k,
226+
distance_strategy=DistanceStrategy.INNER_PRODUCT,
227+
)
228+
assert len(results) == 2
229+
assert results[0]["id_val"] == "best"
230+
assert results[0]["distance"] == pytest.approx(0.2)
231+
assert results[1]["id_val"] == "worst"
232+
assert results[1]["distance"] == pytest.approx(2.0 / 11.0)
233+
234+
def test_mixed_results_euclidean(self) -> None:
235+
"""Tests fusion with EUCLIDEAN (lower is better) for primary."""
236+
primary = [
237+
get_row("closer", 10.5),
238+
get_row("farther", 25.5),
239+
] # Order: closer, farther
240+
secondary = [
241+
get_row("closer", 100.0),
242+
get_row("farther", 10.0),
243+
] # Order: closer, farther
244+
rrf_k = 10
245+
# closer: rank 0 in P + rank 0 in S -> 1/10 + 1/10 = 0.2
246+
# farther: rank 1 in P + rank 1 in S -> 1/11 + 1/11
247+
results = reciprocal_rank_fusion(
248+
primary, # type: ignore
249+
secondary, # type: ignore
250+
rrf_k=rrf_k,
251+
distance_strategy=DistanceStrategy.EUCLIDEAN,
252+
)
225253
assert len(results) == 2
226-
assert results[0]["id_val"] == "docB"
227-
assert results[0]["distance"] == pytest.approx(0.75)
228-
assert results[1]["id_val"] == "docA"
229-
assert results[1]["distance"] == pytest.approx(0.55)
254+
assert results[0]["id_val"] == "closer"
255+
assert results[0]["distance"] == pytest.approx(0.2)
256+
assert results[1]["id_val"] == "farther"
257+
assert results[1]["distance"] == pytest.approx(2.0 / 11.0)
258+
259+
def test_rrf_with_identical_scores(self) -> None:
260+
"""Tests that stable sort is preserved for identical scores."""
261+
# Python's sorted() is stable. p1 appears before p2 in the list.
262+
primary = [get_row("p1", 0.5), get_row("p2", 0.5)]
263+
rrf_k = 60
264+
# Expected order (stable sort): p1 (rank 0), p2 (rank 1)
265+
results = reciprocal_rank_fusion(primary, []) # type: ignore
266+
assert results[0]["id_val"] == "p1"
267+
assert results[0]["distance"] == pytest.approx(1 / 60)
268+
assert results[1]["id_val"] == "p2"
269+
assert results[1]["distance"] == pytest.approx(1 / 61)

0 commit comments

Comments
 (0)