44
55from sqlalchemy import RowMapping
66
7+ from .indexes import DistanceStrategy
8+
9+ from typing import Any , Sequence
10+
11+ def _normalize_scores (
12+ results : Sequence [RowMapping ], is_distance_metric : bool
13+ ) -> list [dict [str , Any ]]:
14+ """Normalizes scores to a 0-1 scale, where 1 is best."""
15+ if not results :
16+ return []
17+
18+ # Get scores from the last column of each result
19+ scores = [list (item .values ())[- 1 ] for item in results ]
20+ min_score , max_score = min (scores ), max (scores )
21+ score_range = max_score - min_score
22+
23+ if score_range == 0 :
24+ # All documents are of the highest quality (1.0)
25+ for item in results :
26+ item ["normalized_score" ] = 1.0
27+ return list (results )
28+
29+ for item in results :
30+ # Access the score again from the last column for calculation
31+ score = list (item .values ())[- 1 ]
32+ normalized = (score - min_score ) / score_range
33+ if is_distance_metric :
34+ # For distance, a lower score is better, so we invert the result.
35+ item ["normalized_score" ] = 1.0 - normalized
36+ else :
37+ # For similarity (like keyword search), a higher score is better.
38+ item ["normalized_score" ] = normalized
39+
40+ return list (results )
41+
742
843def weighted_sum_ranking (
944 primary_search_results : Sequence [RowMapping ],
1045 secondary_search_results : Sequence [RowMapping ],
1146 primary_results_weight : float = 0.5 ,
1247 secondary_results_weight : float = 0.5 ,
1348 fetch_top_k : int = 4 ,
49+ ** kwargs : Any ,
1450) -> Sequence [dict [str , Any ]]:
1551 """
1652 Ranks documents using a weighted sum of scores from two sources.
@@ -32,35 +68,51 @@ def weighted_sum_ranking(
3268 descending order.
3369 """
3470
71+ distance_strategy = kwargs .get ("distance_strategy" , DistanceStrategy .COSINE_DISTANCE )
72+ is_primary_distance = distance_strategy != DistanceStrategy .INNER_PRODUCT
73+
74+ # 1. Normalize both sets of results onto a 0-1 scale
75+ normalized_primary = _normalize_scores (
76+ [dict (row ) for row in primary_search_results ],
77+ is_distance_metric = is_primary_distance
78+ )
79+
80+ # Keyword search relevance is a similarity score (higher is better)
81+ normalized_secondary = _normalize_scores (
82+ [dict (row ) for row in secondary_search_results ],
83+ is_distance_metric = False
84+ )
85+
3586 # stores computed metric with provided distance metric and weights
3687 weighted_scores : dict [str , dict [str , Any ]] = {}
3788
38- # Process results from primary source
39- for row in primary_search_results :
40- values = list (row .values ())
41- doc_id = str (values [0 ]) # first value is doc_id
42- distance = float (values [- 1 ]) # type: ignore # last value is distance
43- row_values = dict (row )
44- row_values ["distance" ] = primary_results_weight * distance
45- weighted_scores [doc_id ] = row_values
46-
47- # Process results from secondary source,
48- # adding to existing scores or creating new ones
49- for row in secondary_search_results :
50- values = list (row .values ())
51- doc_id = str (values [0 ]) # first value is doc_id
52- distance = float (values [- 1 ]) # type: ignore # last value is distance
53- primary_score = (
54- weighted_scores [doc_id ]["distance" ] if doc_id in weighted_scores else 0.0
55- )
56- row_values = dict (row )
57- row_values ["distance" ] = distance * secondary_results_weight + primary_score
58- weighted_scores [doc_id ] = row_values
89+ # Process primary results
90+ for item in normalized_primary :
91+ doc_id = str (list (item .values ())[0 ])
92+ # Overwrite the 'distance' key with the weighted primary score
93+ item ["distance" ] = item ["normalized_score" ] * primary_results_weight
94+ weighted_scores [doc_id ] = item
95+
96+ # Process secondary results
97+ for item in normalized_secondary :
98+ doc_id = str (list (item .values ())[0 ])
99+ secondary_weighted_score = item ["normalized_score" ] * secondary_results_weight
100+
101+ if doc_id in weighted_scores :
102+ # Add to the existing 'distance' score
103+ weighted_scores [doc_id ]["distance" ] += secondary_weighted_score
104+ else :
105+ # Set the 'distance' key for the new item
106+ item ["distance" ] = secondary_weighted_score
107+ weighted_scores [doc_id ] = item
59108
60- # Sort the results by weighted score in descending order
61109 ranked_results = sorted (
62110 weighted_scores .values (), key = lambda item : item ["distance" ], reverse = True
63111 )
112+
113+ for result in ranked_results :
114+ result .pop ("normalized_score" , None )
115+
64116 return ranked_results [:fetch_top_k ]
65117
66118
0 commit comments