Skip to content

Commit 01db41a

Browse files
committed
fix: Normalize results in weighted sum ranking
1 parent 7fee427 commit 01db41a

File tree

1 file changed

+74
-22
lines changed

1 file changed

+74
-22
lines changed

langchain_postgres/v2/hybrid_search_config.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,49 @@
44

55
from sqlalchemy import RowMapping
66

7+
from .indexes import DistanceStrategy
8+
9+
from typing import Any, Sequence
10+
11+
def _normalize_scores(
12+
results: Sequence[RowMapping], is_distance_metric: bool
13+
) -> list[dict[str, Any]]:
14+
"""Normalizes scores to a 0-1 scale, where 1 is best."""
15+
if not results:
16+
return []
17+
18+
# Get scores from the last column of each result
19+
scores = [list(item.values())[-1] for item in results]
20+
min_score, max_score = min(scores), max(scores)
21+
score_range = max_score - min_score
22+
23+
if score_range == 0:
24+
# All documents are of the highest quality (1.0)
25+
for item in results:
26+
item["normalized_score"] = 1.0
27+
return list(results)
28+
29+
for item in results:
30+
# Access the score again from the last column for calculation
31+
score = list(item.values())[-1]
32+
normalized = (score - min_score) / score_range
33+
if is_distance_metric:
34+
# For distance, a lower score is better, so we invert the result.
35+
item["normalized_score"] = 1.0 - normalized
36+
else:
37+
# For similarity (like keyword search), a higher score is better.
38+
item["normalized_score"] = normalized
39+
40+
return list(results)
41+
742

843
def weighted_sum_ranking(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
    **kwargs: Any,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using a weighted sum of scores from two sources.

    Each result set is min-max normalized onto a 0-1 scale (1 = best) before
    weighting; the per-document contributions are then summed and stored under
    the ``"distance"`` key of the returned row dicts.

    Args:
        primary_search_results: Rows from the vector search; the first column
            is the document id and the last column is the score.
        secondary_search_results: Rows from the keyword search, with the same
            column layout. Its relevance is a similarity (higher is better).
        primary_results_weight: Weight applied to normalized primary scores.
        secondary_results_weight: Weight applied to normalized secondary scores.
        fetch_top_k: Maximum number of ranked rows to return.
        **kwargs: May carry ``distance_strategy`` (defaults to
            ``DistanceStrategy.COSINE_DISTANCE``); it decides whether primary
            scores are distances (lower is better) or similarities.

    Returns:
        Up to ``fetch_top_k`` row dicts sorted by combined score in
        descending order.
    """
    strategy = kwargs.get("distance_strategy", DistanceStrategy.COSINE_DISTANCE)
    # Inner product is a similarity (higher is better); every other strategy
    # is treated as a distance metric (lower is better).
    primary_is_distance = strategy != DistanceStrategy.INNER_PRODUCT

    # 1. Normalize both sets of results onto a 0-1 scale
    primary_rows = _normalize_scores(
        [dict(row) for row in primary_search_results],
        is_distance_metric=primary_is_distance,
    )

    # Keyword search relevance is a similarity score (higher is better)
    secondary_rows = _normalize_scores(
        [dict(row) for row in secondary_search_results],
        is_distance_metric=False,
    )

    # stores computed metric with provided distance metric and weights
    combined: dict[str, dict[str, Any]] = {}

    # Seed the table with the weighted primary scores; the combined score
    # reuses the 'distance' key.
    for row in primary_rows:
        key = str(list(row.values())[0])
        row["distance"] = row["normalized_score"] * primary_results_weight
        combined[key] = row

    # Fold in the weighted secondary scores, merging on document id.
    for row in secondary_rows:
        key = str(list(row.values())[0])
        contribution = row["normalized_score"] * secondary_results_weight
        existing = combined.get(key)
        if existing is None:
            # First time we see this document: its score is just the
            # secondary contribution.
            row["distance"] = contribution
            combined[key] = row
        else:
            # Document appeared in both sources: accumulate.
            existing["distance"] += contribution

    ordered = sorted(
        combined.values(), key=lambda row: row["distance"], reverse=True
    )

    # Strip the helper key before handing rows back to callers.
    for row in ordered:
        row.pop("normalized_score", None)

    return ordered[:fetch_top_k]
65117

66118

0 commit comments

Comments
 (0)