Skip to content

Commit f8de351

Browse files
fix: Normalize results in wighted sum ranking (#255)
Fixes Major Issue 1 in #234 --------- Co-authored-by: Averi Kitsch <[email protected]>
1 parent 2ac41ca commit f8de351

File tree

3 files changed

+164
-70
lines changed

3 files changed

+164
-70
lines changed

langchain_postgres/v2/hybrid_search_config.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,38 @@
77
from .indexes import DistanceStrategy
88

99

10+
def _normalize_scores(
11+
results: Sequence[dict[str, Any]], is_distance_metric: bool
12+
) -> Sequence[dict[str, Any]]:
13+
"""Normalizes scores to a 0-1 scale, where 1 is best."""
14+
if not results:
15+
return []
16+
17+
# Get scores from the last column of each result
18+
scores = [float(list(item.values())[-1]) for item in results]
19+
min_score, max_score = min(scores), max(scores)
20+
score_range = max_score - min_score
21+
22+
if score_range == 0:
23+
# All documents are of the highest quality (1.0)
24+
for item in results:
25+
item["normalized_score"] = 1.0
26+
return list(results)
27+
28+
for item in results:
29+
# Access the score again from the last column for calculation
30+
score = list(item.values())[-1]
31+
normalized = (score - min_score) / score_range
32+
if is_distance_metric:
33+
# For distance, a lower score is better, so we invert the result.
34+
item["normalized_score"] = 1.0 - normalized
35+
else:
36+
# For similarity (like keyword search), a higher score is better.
37+
item["normalized_score"] = normalized
38+
39+
return list(results)
40+
41+
1042
def weighted_sum_ranking(
1143
primary_search_results: Sequence[RowMapping],
1244
secondary_search_results: Sequence[RowMapping],
@@ -35,35 +67,52 @@ def weighted_sum_ranking(
3567
descending order.
3668
"""
3769

70+
distance_strategy = kwargs.get(
71+
"distance_strategy", DistanceStrategy.COSINE_DISTANCE
72+
)
73+
is_primary_distance = distance_strategy != DistanceStrategy.INNER_PRODUCT
74+
75+
# Normalize both sets of results onto a 0-1 scale
76+
normalized_primary = _normalize_scores(
77+
[dict(row) for row in primary_search_results],
78+
is_distance_metric=is_primary_distance,
79+
)
80+
81+
# Keyword search relevance is a similarity score (higher is better)
82+
normalized_secondary = _normalize_scores(
83+
[dict(row) for row in secondary_search_results], is_distance_metric=False
84+
)
85+
3886
# stores computed metric with provided distance metric and weights
3987
weighted_scores: dict[str, dict[str, Any]] = {}
4088

41-
# Process results from primary source
42-
for row in primary_search_results:
43-
values = list(row.values())
44-
doc_id = str(values[0]) # first value is doc_id
45-
distance = float(values[-1]) # type: ignore # last value is distance
46-
row_values = dict(row)
47-
row_values["distance"] = primary_results_weight * distance
48-
weighted_scores[doc_id] = row_values
49-
50-
# Process results from secondary source,
51-
# adding to existing scores or creating new ones
52-
for row in secondary_search_results:
53-
values = list(row.values())
54-
doc_id = str(values[0]) # first value is doc_id
55-
distance = float(values[-1]) # type: ignore # last value is distance
56-
primary_score = (
57-
weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
58-
)
59-
row_values = dict(row)
60-
row_values["distance"] = distance * secondary_results_weight + primary_score
61-
weighted_scores[doc_id] = row_values
89+
# Process primary results
90+
for item in normalized_primary:
91+
doc_id = str(list(item.values())[0])
92+
# Set the 'distance' key with the weighted primary score
93+
item["distance"] = item["normalized_score"] * primary_results_weight
94+
weighted_scores[doc_id] = item
95+
96+
# Process secondary results
97+
for item in normalized_secondary:
98+
doc_id = str(list(item.values())[0])
99+
secondary_weighted_score = item["normalized_score"] * secondary_results_weight
100+
101+
if doc_id in weighted_scores:
102+
# Add to the existing 'distance' score
103+
weighted_scores[doc_id]["distance"] += secondary_weighted_score
104+
else:
105+
# Set the 'distance' key for the new item
106+
item["distance"] = secondary_weighted_score
107+
weighted_scores[doc_id] = item
62108

63-
# Sort the results by weighted score in descending order
64109
ranked_results = sorted(
65110
weighted_scores.values(), key=lambda item: item["distance"], reverse=True
66111
)
112+
113+
for result in ranked_results:
114+
result.pop("normalized_score", None)
115+
67116
return ranked_results[:fetch_top_k]
68117

69118

tests/unit_tests/v2/test_async_pg_vectorstore_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ async def test_hybrid_search_weighted_sum_vector_bias(
478478
result_ids = [doc.metadata["doc_id_key"] for doc in results]
479479

480480
assert len(result_ids) > 0
481-
assert result_ids[0] == "hs_doc_orange_fruit"
481+
assert result_ids[0] == "hs_doc_generic_tech"
482482

483483
async def test_hybrid_search_weighted_sum_fts_bias(
484484
self,
@@ -611,7 +611,7 @@ async def test_hybrid_search_fts_empty_results(
611611
assert len(result_ids) > 0
612612
assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids
613613
# The top result should be one of the apple documents based on vector search
614-
assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat")
614+
assert results[0].metadata["doc_id_key"].startswith("hs_doc_apple_fruit")
615615

616616
async def test_hybrid_search_vector_empty_results_effectively(
617617
self,
@@ -639,7 +639,7 @@ async def test_hybrid_search_vector_empty_results_effectively(
639639

640640
# Expect results based purely on FTS search for "orange fruit"
641641
assert len(result_ids) == 1
642-
assert result_ids[0] == "hs_doc_generic_tech"
642+
assert result_ids[0] == "hs_doc_orange_fruit"
643643

644644
async def test_hybrid_search_without_tsv_column(
645645
self,

tests/unit_tests/v2/test_hybrid_search_config.py

Lines changed: 90 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,82 +27,127 @@ def get_row(doc_id: str, score: float, content: str = "content") -> RowMapping:
2727

2828
class TestWeightedSumRanking:
2929
def test_empty_inputs(self) -> None:
30+
"""Tests that the function handles empty inputs gracefully."""
3031
results = weighted_sum_ranking([], [])
3132
assert results == []
3233

33-
def test_primary_only(self) -> None:
34+
def test_primary_only_cosine_default(self) -> None:
35+
"""Tests ranking with only primary results using default cosine distance."""
3436
primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
35-
# Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
36-
results = weighted_sum_ranking( # type: ignore
37+
# --- Calculation (Cosine = lower is better) ---
38+
# Scores: [0.8, 0.6]. Range: 0.2. Min: 0.6.
39+
# p1 norm: 1.0 - ((0.8 - 0.6) / 0.2) = 0.0
40+
# p2 norm: 1.0 - ((0.6 - 0.6) / 0.2) = 1.0
41+
# Weighted (0.5): p1 = 0.0, p2 = 0.5
42+
# Order: p2, p1
43+
results = weighted_sum_ranking(
3744
primary, # type: ignore
3845
[],
39-
primary_results_weight=0.5,
40-
secondary_results_weight=0.5,
4146
)
4247
assert len(results) == 2
43-
assert results[0]["id_val"] == "p1"
44-
assert results[0]["distance"] == pytest.approx(0.4)
45-
assert results[1]["id_val"] == "p2"
46-
assert results[1]["distance"] == pytest.approx(0.3)
48+
assert results[0]["id_val"] == "p2"
49+
assert results[0]["distance"] == pytest.approx(0.5)
50+
assert results[1]["id_val"] == "p1"
51+
assert results[1]["distance"] == pytest.approx(0.0)
4752

4853
def test_secondary_only(self) -> None:
49-
secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
50-
# Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
54+
"""Tests ranking with only secondary (keyword) results."""
55+
secondary = [get_row("s1", 15.0), get_row("s2", 5.0)]
56+
# --- Calculation (Keyword = higher is better) ---
57+
# Scores: [15.0, 5.0]. Range: 10.0. Min: 5.0.
58+
# s1 norm: (15.0 - 5.0) / 10.0 = 1.0
59+
# s2 norm: (5.0 - 5.0) / 10.0 = 0.0
60+
# Weighted (0.5): s1 = 0.5, s2 = 0.0
61+
# Order: s1, s2
5162
results = weighted_sum_ranking(
5263
[],
5364
secondary, # type: ignore
54-
primary_results_weight=0.5,
55-
secondary_results_weight=0.5,
5665
)
5766
assert len(results) == 2
5867
assert results[0]["id_val"] == "s1"
59-
assert results[0]["distance"] == pytest.approx(0.45)
68+
assert results[0]["distance"] == pytest.approx(0.5)
6069
assert results[1]["id_val"] == "s2"
61-
assert results[1]["distance"] == pytest.approx(0.35)
70+
assert results[1]["distance"] == pytest.approx(0.0)
6271

63-
def test_mixed_results_default_weights(self) -> None:
72+
def test_mixed_results_cosine(self) -> None:
73+
"""Tests combining cosine (lower is better) and keyword (higher is better) scores."""
6474
primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
65-
secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
66-
# Weights are 0.5, 0.5
67-
# common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85
68-
# p_only_score = (0.7 * 0.5) = 0.35
69-
# s_only_score = (0.6 * 0.5) = 0.30
70-
# Order: common (0.85), p_only (0.35), s_only (0.30)
71-
72-
results = weighted_sum_ranking(primary, secondary) # type: ignore
75+
secondary = [get_row("common", 9.0), get_row("s_only", 6.0)]
76+
# --- Calculation ---
77+
# Primary norm (inverted): common=0.0, p_only=1.0
78+
# Secondary norm: common=1.0, s_only=0.0
79+
# Weighted (0.5):
80+
# common = (0.0 * 0.5) + (1.0 * 0.5) = 0.5
81+
# p_only = (1.0 * 0.5) + 0 = 0.5
82+
# s_only = 0 + (0.0 * 0.5) = 0.0
83+
results = weighted_sum_ranking(
84+
primary, # type: ignore
85+
secondary, # type: ignore
86+
)
7387
assert len(results) == 3
74-
assert results[0]["id_val"] == "common"
75-
assert results[0]["distance"] == pytest.approx(0.85)
76-
assert results[1]["id_val"] == "p_only"
77-
assert results[1]["distance"] == pytest.approx(0.35)
88+
# Check that the top two results have the correct score and IDs (order may vary)
89+
top_ids = {res["id_val"] for res in results[:2]}
90+
assert top_ids == {"common", "p_only"}
91+
assert results[0]["distance"] == pytest.approx(0.5)
92+
assert results[1]["distance"] == pytest.approx(0.5)
7893
assert results[2]["id_val"] == "s_only"
79-
assert results[2]["distance"] == pytest.approx(0.30)
94+
assert results[2]["distance"] == pytest.approx(0.0)
8095

81-
def test_mixed_results_custom_weights(self) -> None:
82-
primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2
83-
secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4
84-
# Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6
96+
def test_primary_max_inner_product(self) -> None:
97+
"""Tests using MAX_INNER_PRODUCT (higher is better) for primary search."""
98+
primary = [get_row("best", 0.9), get_row("worst", 0.1)]
99+
secondary = [get_row("best", 20.0), get_row("worst", 5.0)]
100+
# --- Calculation ---
101+
# Primary norm (NOT inverted): best=1.0, worst=0.0
102+
# Secondary norm: best=1.0, worst=0.0
103+
# Weighted (0.5):
104+
# best = (1.0 * 0.5) + (1.0 * 0.5) = 1.0
105+
# worst = (0.0 * 0.5) + (0.0 * 0.5) = 0.0
106+
results = weighted_sum_ranking(
107+
primary, # type: ignore
108+
secondary, # type: ignore
109+
distance_strategy=DistanceStrategy.INNER_PRODUCT,
110+
)
111+
assert len(results) == 2
112+
assert results[0]["id_val"] == "best"
113+
assert results[0]["distance"] == pytest.approx(1.0)
114+
assert results[1]["id_val"] == "worst"
115+
assert results[1]["distance"] == pytest.approx(0.0)
85116

117+
def test_primary_euclidean(self) -> None:
118+
"""Tests using EUCLIDEAN (lower is better) for primary search."""
119+
primary = [get_row("closer", 10.5), get_row("farther", 25.5)]
120+
secondary = [get_row("closer", 100.0), get_row("farther", 10.0)]
121+
# --- Calculation ---
122+
# Primary norm (inverted): closer=1.0, farther=0.0
123+
# Secondary norm: closer=1.0, farther=0.0
124+
# Weighted (0.5):
125+
# closer = (1.0 * 0.5) + (1.0 * 0.5) = 1.0
126+
# farther = (0.0 * 0.5) + (0.0 * 0.5) = 0.0
86127
results = weighted_sum_ranking(
87128
primary, # type: ignore
88129
secondary, # type: ignore
89-
primary_results_weight=0.2,
90-
secondary_results_weight=0.8,
130+
distance_strategy=DistanceStrategy.EUCLIDEAN,
91131
)
92-
assert len(results) == 1
93-
assert results[0]["id_val"] == "d1"
94-
assert results[0]["distance"] == pytest.approx(0.6)
132+
assert len(results) == 2
133+
assert results[0]["id_val"] == "closer"
134+
assert results[0]["distance"] == pytest.approx(1.0)
135+
assert results[1]["id_val"] == "farther"
136+
assert results[1]["distance"] == pytest.approx(0.0)
95137

96138
def test_fetch_top_k(self) -> None:
139+
"""Tests that fetch_top_k correctly limits the number of results."""
97140
primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
98-
# Scores: 1.0, 0.9, 0.8, 0.7, 0.6
99-
# Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
100-
results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore
141+
# p0=1.0, p1=0.9, p2=0.8, p3=0.7, p4=0.6
142+
# The best scores (lowest distance) are p4 and p3
143+
results = weighted_sum_ranking(
144+
primary, # type: ignore
145+
[],
146+
fetch_top_k=2,
147+
)
101148
assert len(results) == 2
102-
assert results[0]["id_val"] == "p0"
103-
assert results[0]["distance"] == pytest.approx(0.5)
104-
assert results[1]["id_val"] == "p1"
105-
assert results[1]["distance"] == pytest.approx(0.45)
149+
assert results[0]["id_val"] == "p4" # Has the best normalized score
150+
assert results[1]["id_val"] == "p3"
106151

107152

108153
class TestReciprocalRankFusion:

0 commit comments

Comments
 (0)