@@ -10,13 +10,10 @@ use function_name::named;
 use futures::StreamExt;
 use futures::stream::FuturesUnordered;
 use reqwest::StatusCode;
-use std::cmp::Reverse;
 use std::collections::HashMap;
-use std::collections::HashSet;
 use std::sync::Arc;
 use tauri::{AppHandle, Manager};
 use tokio::time::{Duration, timeout};
-
 #[named]
 #[tauri::command]
 pub async fn query_coco_fusion(
@@ -187,7 +184,6 @@ async fn query_coco_fusion_multi_query_sources(
 
     let mut futures = FuturesUnordered::new();
 
-    let query_source_list_len = query_source_trait_object_list.len();
     for query_source_trait_object in query_source_trait_object_list {
         let query_source = query_source_trait_object.get_type().clone();
         let tauri_app_handle_clone = tauri_app_handle.clone();
@@ -208,14 +204,8 @@ async fn query_coco_fusion_multi_query_sources(
     }
 
     let mut total_hits = 0;
-    let mut need_rerank = true; //TODO set default to false when boost supported in Pizza
     let mut failed_requests = Vec::new();
-    let mut all_hits: Vec<(String, QueryHits, f64)> = Vec::new();
-    let mut hits_per_source: HashMap<String, Vec<(QueryHits, f64)>> = HashMap::new();
-
-    if query_source_list_len > 1 {
-        need_rerank = true; // If we have more than one source, we need to rerank the hits
-    }
+    let mut all_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
 
     while let Some((query_source, timeout_result)) = futures.next().await {
         match timeout_result {
@@ -246,12 +236,10 @@ async fn query_coco_fusion_multi_query_sources(
                         document,
                     };
 
-                    all_hits.push((source_id.clone(), query_hit.clone(), score));
-
-                    hits_per_source
+                    all_hits_grouped_by_source_id
                         .entry(source_id.clone())
                         .or_insert_with(Vec::new)
-                        .push((query_hit, score));
+                        .push(query_hit);
                 }
             }
             Err(search_error) => {
@@ -267,109 +255,117 @@ async fn query_coco_fusion_multi_query_sources(
         }
     }
 
-    // Sort hits within each source by score (descending)
-    for hits in hits_per_source.values_mut() {
-        hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Greater));
-    }
-
-    let total_sources = hits_per_source.len();
-    let max_hits_per_source = if total_sources > 0 {
-        size as usize / total_sources
-    } else {
-        size as usize
-    };
+    let n_sources = all_hits_grouped_by_source_id.len();
 
-    let mut final_hits = Vec::new();
-    let mut seen_docs = HashSet::new(); // To track documents we've already added
-
-    // Distribute hits fairly across sources
-    for (_source_id, hits) in &mut hits_per_source {
-        let take_count = hits.len().min(max_hits_per_source);
-        for (doc, score) in hits.drain(0..take_count) {
-            if !seen_docs.contains(&doc.document.id) {
-                seen_docs.insert(doc.document.id.clone());
-                log::debug!(
-                    "collect doc: {}, {:?}, {}",
-                    doc.document.id,
-                    doc.document.title,
-                    score
-                );
-                final_hits.push(doc);
-            }
-        }
+    if n_sources == 0 {
+        return Ok(MultiSourceQueryResponse {
+            failed: Vec::new(),
+            hits: Vec::new(),
+            total_hits: 0,
+        });
     }
 
-    log::debug!("final hits: {:?}", final_hits.len());
-
-    let mut unique_sources = HashSet::new();
-    for hit in &final_hits {
-        if let Some(source) = &hit.source {
-            if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
-                unique_sources.insert(&source.id);
-            }
-        }
+    /*
+     * Re-rank the hits
+     */
+    if n_sources > 1 {
+        boosted_levenshtein_rerank(&query_keyword, &mut all_hits_grouped_by_source_id);
     }
 
-    log::debug!(
-        "Multiple sources found: {:?}, no rerank needed",
-        unique_sources
-    );
-
-    if unique_sources.len() < 1 {
-        need_rerank = false; // If we have hits from multiple sources, we don't need to rerank
+    /*
+     * Sort hits within each source by score (descending) in case data sources
+     * do not sort them
+     */
+    for hits in all_hits_grouped_by_source_id.values_mut() {
+        hits.sort_by(|a, b| {
+            b.score
+                .partial_cmp(&a.score)
+                .unwrap_or(std::cmp::Ordering::Greater)
+        });
     }
 
-    if need_rerank && final_hits.len() > 1 {
-        // Precollect (index, title)
-        let titles_to_score: Vec<(usize, &str)> = final_hits
-            .iter()
-            .enumerate()
-            .filter_map(|(idx, hit)| {
-                let source = hit.source.as_ref()?;
-                let title = hit.document.title.as_deref()?;
+    /*
+     * Collect hits evenly across sources, to ensure:
+     *
+     * 1. All sources have hits returned
+     * 2. Query sources with many hits won't dominate
+     */
+    let mut final_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
+    let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new();
+
+    // `max_hits_per_source` could be 0, making `final_hits_grouped_by_source_id`
+    // empty. We don't need to worry about that case, as we will populate
+    // hits later.
+    let max_hits_per_source = size as usize / n_sources;
+    for (source_id, hits) in all_hits_grouped_by_source_id.iter() {
+        let hits_taken = if hits.len() > max_hits_per_source {
+            pruned.insert(&source_id, &hits[max_hits_per_source..]);
+            hits[0..max_hits_per_source].to_vec()
+        } else {
+            hits.clone()
+        };
+
+        final_hits_grouped_by_source_id.insert(source_id.clone(), hits_taken);
+    }
 
-                if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
-                    Some((idx, title))
-                } else {
-                    None
+    let final_hits_len = final_hits_grouped_by_source_id
+        .iter()
+        .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
+    let pruned_len = pruned
+        .iter()
+        .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
+
+    /*
+     * If we still need more hits, take the highest-scoring ones from `pruned`.
+     *
+     * `pruned` contains sorted slices; we scan them the way an n-way merge
+     * extracts the element with the greatest value.
+     */
+    if final_hits_len < size as usize {
+        let n_need = size as usize - final_hits_len;
+        let n_have = pruned_len;
+        let n_take = n_have.min(n_need);
+
+        for _ in 0..n_take {
+            let mut highest_score_hit: Option<(&str, &QueryHits)> = None;
+            for (source_id, sorted_hits) in pruned.iter() {
+                if sorted_hits.is_empty() {
+                    continue;
                 }
-            })
-            .collect();
 
-        // Score them
-        let scored_hits = boosted_levenshtein_rerank(query_keyword.as_str(), titles_to_score);
+                let hit = &sorted_hits[0];
 
-        // Sort descending by score
-        let mut scored_hits = scored_hits;
-        scored_hits.sort_by_key(|&(_, score)| Reverse((score * 1000.0) as u64));
+                let have_higher_score_hit = match highest_score_hit {
+                    Some((_, current_highest_score_hit)) => {
+                        hit.score > current_highest_score_hit.score
+                    }
+                    None => true,
+                };
 
-        // Apply new scores to final_hits
-        for (idx, score) in scored_hits.into_iter().take(size as usize) {
-            final_hits[idx].score = score;
-        }
-    } else if final_hits.len() < size as usize {
-        // If we still need more hits, take the highest-scoring remaining ones
+                if have_higher_score_hit {
+                    highest_score_hit = Some((*source_id, hit));
+                }
+            }
 
-        let remaining_needed = size as usize - final_hits.len();
+            let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements, so `highest_score_hit` should be set");
 
-        // Sort all hits by score descending, removing duplicates by document ID
-        all_hits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
+            // Advance the winning source's slice by one element so the same
+            // hit cannot be taken twice (only the final winner is advanced).
+            let sorted_hits = pruned
+                .get_mut(source_id)
+                .expect("`source_id` comes from iterating `pruned`, so it must exist");
+            *sorted_hits = &sorted_hits[1..];
 
-        let extra_hits = all_hits
-            .into_iter()
-            .filter(|(source_id, _, _)| hits_per_source.contains_key(source_id)) // Only take from known sources
-            .filter_map(|(_, doc, _)| {
-                if !seen_docs.contains(&doc.document.id) {
-                    seen_docs.insert(doc.document.id.clone());
-                    Some(doc)
-                } else {
-                    None
-                }
-            })
-            .take(remaining_needed)
-            .collect::<Vec<_>>();
+            final_hits_grouped_by_source_id
+                .get_mut(source_id)
+                .expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist")
+                .push(hit.clone());
+        }
+    }
 
-        final_hits.extend(extra_hits);
+    let mut final_hits = Vec::new();
+    for (_source_id, hits) in final_hits_grouped_by_source_id {
+        final_hits.extend(hits);
     }
 
     // **Sort final hits by score descending**
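
Note on the top-up loop above: after the even split, every slice left in `pruned` is still sorted descending by score, so the next-best remaining hit is always one of the slice heads, and repeatedly taking the best head is exactly the extract step of an n-way merge. A minimal standalone sketch of that step, using plain (title, score) pairs instead of `QueryHits` (the `take_top_n` helper, source names, and scores are invented for illustration):

use std::collections::HashMap;

fn take_top_n<'a>(
    pruned: &mut HashMap<&'a str, &'a [(&'a str, f64)]>,
    n_take: usize,
) -> Vec<(&'a str, f64)> {
    let mut taken = Vec::new();
    for _ in 0..n_take {
        // Find the source whose head element currently has the highest score.
        let mut best: Option<(&str, f64)> = None;
        for (&source_id, hits) in pruned.iter() {
            if let Some(&(_, score)) = hits.first() {
                if best.map_or(true, |(_, s)| score > s) {
                    best = Some((source_id, score));
                }
            }
        }
        let Some((source_id, _)) = best else { break }; // all slices exhausted
        // Pop the winning head and advance only that source's slice.
        let hits = pruned.get_mut(source_id).unwrap();
        taken.push(hits[0]);
        *hits = &hits[1..];
    }
    taken
}

fn main() {
    let a: &[(&str, f64)] = &[("a1", 0.9), ("a2", 0.4)];
    let b: &[(&str, f64)] = &[("b1", 0.7)];
    let mut pruned = HashMap::from([("src_a", a), ("src_b", b)]);
    // Prints [("a1", 0.9), ("b1", 0.7)]: the two highest heads across sources.
    println!("{:?}", take_top_n(&mut pruned, 2));
}
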
@@ -379,6 +375,11 @@ async fn query_coco_fusion_multi_query_sources(
             .unwrap_or(std::cmp::Ordering::Equal)
     });
 
+    // Truncate `final_hits` in case it contains more than `size` hits.
+    //
+    // Technically we are safe without this, but since it is trivial, double-check.
+    final_hits.truncate(size as usize);
+
     if final_hits.len() < 5 {
         //TODO: Add a recommendation system to suggest more sources
         log::info!(
@@ -395,30 +396,85 @@ async fn query_coco_fusion_multi_query_sources(
     })
 }
 
-fn boosted_levenshtein_rerank(query: &str, titles: Vec<(usize, &str)>) -> Vec<(usize, f64)> {
-    use strsim::levenshtein;
+use std::collections::HashSet;
+use strsim::levenshtein;
 
+fn boosted_levenshtein_rerank(
+    query: &str,
+    all_hits_grouped_by_source_id: &mut HashMap<String, Vec<QueryHits>>,
+) {
     let query_lower = query.to_lowercase();
 
-    titles
-        .into_iter()
-        .map(|(idx, title)| {
-            let mut score = 0.0;
+    for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() {
+        // Skip special sources like the calculator
+        if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
+            continue;
+        }
+
+        for hit in hits.iter_mut() {
+            let document_title = hit.document.title.as_deref().unwrap_or("");
+            let document_title_lowercase = document_title.to_lowercase();
 
-            if title.contains(query) {
-                score += 0.4;
-            } else if title.to_lowercase().contains(&query_lower) {
-                score += 0.2;
-            }
+            let new_score = {
+                let mut score = 0.0;
 
-            let dist = levenshtein(&query_lower, &title.to_lowercase());
-            let max_len = query_lower.len().max(title.len());
-            if max_len > 0 {
-                score += (1.0 - (dist as f64 / max_len as f64)) as f32;
-            }
+                // --- Exact or substring boost ---
+                if document_title.contains(query) {
+                    score += 0.4;
+                } else if document_title_lowercase.contains(&query_lower) {
+                    score += 0.2;
+                }
+
+                // --- Levenshtein distance (character similarity) ---
+                let dist = levenshtein(&query_lower, &document_title_lowercase);
+                let max_len = query_lower.len().max(document_title.len());
+                let levenshtein_score = if max_len > 0 {
+                    (1.0 - (dist as f64 / max_len as f64)) as f32
+                } else {
+                    0.0
+                };
+
+                // --- Jaccard similarity (token overlap) ---
+                let jaccard_score = jaccard_similarity(&query_lower, &document_title_lowercase);
+
+                // --- Combine scores (weights adjustable) ---
+                // Levenshtein emphasizes surface similarity;
+                // Jaccard emphasizes term overlap (a semantic hint).
+                let hybrid_score = 0.7 * levenshtein_score + 0.3 * jaccard_score;
+
+                // --- Apply hybrid score ---
+                score += hybrid_score;
+
+                // --- Limit score range ---
+                score.min(1.0) as f64
+            };
+
+            hit.score = new_score;
+        }
+    }
+}
+
+/// Compute token-based Jaccard similarity
+fn jaccard_similarity(a: &str, b: &str) -> f32 {
+    let a_tokens: HashSet<_> = tokenize(a).into_iter().collect();
+    let b_tokens: HashSet<_> = tokenize(b).into_iter().collect();
+
+    if a_tokens.is_empty() || b_tokens.is_empty() {
+        return 0.0;
+    }
+
+    let intersection = a_tokens.intersection(&b_tokens).count() as f32;
+    let union = a_tokens.union(&b_tokens).count() as f32;
+
+    intersection / union
+}
 
-            (idx, score.min(1.0) as f64)
-        })
+/// Basic tokenizer (case-insensitive, alphanumeric words only)
+fn tokenize(text: &str) -> Vec<String> {
+    text.to_lowercase()
+        .split(|c: char| !c.is_alphanumeric())
+        .filter(|s| !s.is_empty())
+        .map(|s| s.to_string())
         .collect()
 }
 
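
Taken together, the new scoring is: a substring boost (0.4 case-sensitive, 0.2 case-insensitive) plus a 70/30 blend of normalized Levenshtein similarity and token-level Jaccard overlap, clamped to 1.0. A self-contained sketch of that formula with the same constants (it assumes the `strsim` crate in Cargo.toml; `tokenize` and `jaccard_similarity` mirror the helpers in the diff, while `hybrid_score` and the sample queries are invented for illustration):

use std::collections::HashSet;
use strsim::levenshtein;

// Case-insensitive, alphanumeric-only token set.
fn tokenize(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(str::to_string)
        .collect()
}

// Token-overlap similarity: intersection size over union size.
fn jaccard_similarity(a: &str, b: &str) -> f32 {
    let a_tokens = tokenize(a);
    let b_tokens = tokenize(b);
    if a_tokens.is_empty() || b_tokens.is_empty() {
        return 0.0;
    }
    a_tokens.intersection(&b_tokens).count() as f32
        / a_tokens.union(&b_tokens).count() as f32
}

fn hybrid_score(query: &str, title: &str) -> f64 {
    let query_lower = query.to_lowercase();
    let title_lower = title.to_lowercase();

    let mut score: f32 = 0.0;
    // Substring boost: a case-sensitive match earns more than a
    // case-insensitive one.
    if title.contains(query) {
        score += 0.4;
    } else if title_lower.contains(&query_lower) {
        score += 0.2;
    }

    // Normalized Levenshtein similarity: 1.0 for identical strings.
    let dist = levenshtein(&query_lower, &title_lower);
    let max_len = query_lower.len().max(title.len());
    let levenshtein_score = if max_len > 0 {
        (1.0 - (dist as f64 / max_len as f64)) as f32
    } else {
        0.0
    };

    // 70% character-level similarity, 30% token overlap, capped at 1.0.
    score += 0.7 * levenshtein_score + 0.3 * jaccard_similarity(&query_lower, &title_lower);
    score.min(1.0) as f64
}

fn main() {
    println!("{:.3}", hybrid_score("rust book", "The Rust Book")); // ~0.885
    println!("{:.3}", hybrid_score("rust book", "Cooking for Two")); // much lower
}

For "rust book" against "The Rust Book", the case-insensitive substring match adds 0.2, the tokens {rust, book} against {the, rust, book} give a Jaccard score of 2/3, and a Levenshtein distance of 4 over a max length of 13 gives roughly 0.69, for a total near 0.88; the unrelated title gets no boost and no token overlap, so it scores far lower.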