Commit b8affcd

Authored by ayangweb, SteveLauC, and medcl
refactor: improve sorting logic of search results (#910)
* refactor: improve sorting logic of search results
* refactor: update
* wip
* feat: support switching groups via keyboard shortcuts (#911)
* feat: support switching groups via keyboard shortcuts
* refactor: update
* docs: update changelog
* refactor post-querying logic
* refactor post-querying logic
* refactor post-querying logic
* refactor post-querying logic
* refactor: refactoring rerank function
* refactor: refactoring rerank with intelligent hybrid scorer
* chore: remove debug logging
* chore: fix format

--------

Co-authored-by: Steve Lau <[email protected]>
Co-authored-by: medcl <[email protected]>
1 parent 595ae67 commit b8affcd

File tree

2 files changed: +180 −122 lines changed


src-tauri/src/search/mod.rs

Lines changed: 176 additions & 120 deletions
```diff
@@ -10,13 +10,10 @@ use function_name::named;
 use futures::StreamExt;
 use futures::stream::FuturesUnordered;
 use reqwest::StatusCode;
-use std::cmp::Reverse;
 use std::collections::HashMap;
-use std::collections::HashSet;
 use std::sync::Arc;
 use tauri::{AppHandle, Manager};
 use tokio::time::{Duration, timeout};
-
 #[named]
 #[tauri::command]
 pub async fn query_coco_fusion(
@@ -187,7 +184,6 @@ async fn query_coco_fusion_multi_query_sources(
 
     let mut futures = FuturesUnordered::new();
 
-    let query_source_list_len = query_source_trait_object_list.len();
     for query_source_trait_object in query_source_trait_object_list {
         let query_source = query_source_trait_object.get_type().clone();
         let tauri_app_handle_clone = tauri_app_handle.clone();
@@ -208,14 +204,8 @@ async fn query_coco_fusion_multi_query_sources(
     }
 
     let mut total_hits = 0;
-    let mut need_rerank = true; //TODO set default to false when boost supported in Pizza
     let mut failed_requests = Vec::new();
-    let mut all_hits: Vec<(String, QueryHits, f64)> = Vec::new();
-    let mut hits_per_source: HashMap<String, Vec<(QueryHits, f64)>> = HashMap::new();
-
-    if query_source_list_len > 1 {
-        need_rerank = true; // If we have more than one source, we need to rerank the hits
-    }
+    let mut all_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
 
     while let Some((query_source, timeout_result)) = futures.next().await {
         match timeout_result {
@@ -246,12 +236,10 @@ async fn query_coco_fusion_multi_query_sources(
                         document,
                     };
 
-                    all_hits.push((source_id.clone(), query_hit.clone(), score));
-
-                    hits_per_source
+                    all_hits_grouped_by_source_id
                         .entry(source_id.clone())
                         .or_insert_with(Vec::new)
-                        .push((query_hit, score));
+                        .push(query_hit);
                 }
             }
             Err(search_error) => {
@@ -267,109 +255,117 @@ async fn query_coco_fusion_multi_query_sources(
         }
     }
 
-    // Sort hits within each source by score (descending)
-    for hits in hits_per_source.values_mut() {
-        hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Greater));
-    }
-
-    let total_sources = hits_per_source.len();
-    let max_hits_per_source = if total_sources > 0 {
-        size as usize / total_sources
-    } else {
-        size as usize
-    };
+    let n_sources = all_hits_grouped_by_source_id.len();
 
-    let mut final_hits = Vec::new();
-    let mut seen_docs = HashSet::new(); // To track documents we've already added
-
-    // Distribute hits fairly across sources
-    for (_source_id, hits) in &mut hits_per_source {
-        let take_count = hits.len().min(max_hits_per_source);
-        for (doc, score) in hits.drain(0..take_count) {
-            if !seen_docs.contains(&doc.document.id) {
-                seen_docs.insert(doc.document.id.clone());
-                log::debug!(
-                    "collect doc: {}, {:?}, {}",
-                    doc.document.id,
-                    doc.document.title,
-                    score
-                );
-                final_hits.push(doc);
-            }
-        }
+    if n_sources == 0 {
+        return Ok(MultiSourceQueryResponse {
+            failed: Vec::new(),
+            hits: Vec::new(),
+            total_hits: 0,
+        });
     }
 
-    log::debug!("final hits: {:?}", final_hits.len());
-
-    let mut unique_sources = HashSet::new();
-    for hit in &final_hits {
-        if let Some(source) = &hit.source {
-            if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
-                unique_sources.insert(&source.id);
-            }
-        }
+    /*
+     * Re-rank the hits
+     */
+    if n_sources > 1 {
+        boosted_levenshtein_rerank(&query_keyword, &mut all_hits_grouped_by_source_id);
     }
 
-    log::debug!(
-        "Multiple sources found: {:?}, no rerank needed",
-        unique_sources
-    );
-
-    if unique_sources.len() < 1 {
-        need_rerank = false; // If we have hits from multiple sources, we don't need to rerank
+    /*
+     * Sort hits within each source by score (descending) in case data sources
+     * do not sort them
+     */
+    for hits in all_hits_grouped_by_source_id.values_mut() {
+        hits.sort_by(|a, b| {
+            b.score
+                .partial_cmp(&a.score)
+                .unwrap_or(std::cmp::Ordering::Greater)
+        });
     }
 
-    if need_rerank && final_hits.len() > 1 {
-        // Precollect (index, title)
-        let titles_to_score: Vec<(usize, &str)> = final_hits
-            .iter()
-            .enumerate()
-            .filter_map(|(idx, hit)| {
-                let source = hit.source.as_ref()?;
-                let title = hit.document.title.as_deref()?;
+    /*
+     * Collect hits evenly across sources, to ensure:
+     *
+     * 1. All sources have hits returned
+     * 2. Query sources with many hits won't dominate
+     */
+    let mut final_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
+    let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new();
+
+    // max_hits_per_source could be 0, then `final_hits_grouped_by_source_id`
+    // would be empty. But we don't need to worry about this case as we will
+    // populate hits later.
+    let max_hits_per_source = size as usize / n_sources;
+    for (source_id, hits) in all_hits_grouped_by_source_id.iter() {
+        let hits_taken = if hits.len() > max_hits_per_source {
+            pruned.insert(&source_id, &hits[max_hits_per_source..]);
+            hits[0..max_hits_per_source].to_vec()
+        } else {
+            hits.clone()
+        };
+
+        final_hits_grouped_by_source_id.insert(source_id.clone(), hits_taken);
+    }
 
-                if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
-                    Some((idx, title))
-                } else {
-                    None
+    let final_hits_len = final_hits_grouped_by_source_id
+        .iter()
+        .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
+    let pruned_len = pruned
+        .iter()
+        .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
+
+    /*
+     * If we still need more hits, take the highest-scoring from `pruned`
+     *
+     * `pruned` contains sorted arrays, we scan it in a way similar to
+     * how n-way-merge-sort extracts the element with the greatest value.
+     */
+    if final_hits_len < size as usize {
+        let n_need = size as usize - final_hits_len;
+        let n_have = pruned_len;
+        let n_take = n_have.min(n_need);
+
+        for _ in 0..n_take {
+            let mut highest_score_hit: Option<(&str, &QueryHits)> = None;
+            for (source_id, sorted_hits) in pruned.iter_mut() {
+                if sorted_hits.is_empty() {
+                    continue;
                 }
-            })
-            .collect();
 
-        // Score them
-        let scored_hits = boosted_levenshtein_rerank(query_keyword.as_str(), titles_to_score);
+                let hit = &sorted_hits[0];
 
-        // Sort descending by score
-        let mut scored_hits = scored_hits;
-        scored_hits.sort_by_key(|&(_, score)| Reverse((score * 1000.0) as u64));
+                let have_higher_score_hit = match highest_score_hit {
+                    Some((_, current_highest_score_hit)) => {
+                        hit.score > current_highest_score_hit.score
+                    }
+                    None => true,
+                };
 
-        // Apply new scores to final_hits
-        for (idx, score) in scored_hits.into_iter().take(size as usize) {
-            final_hits[idx].score = score;
-        }
-    } else if final_hits.len() < size as usize {
-        // If we still need more hits, take the highest-scoring remaining ones
+                if have_higher_score_hit {
+                    highest_score_hit = Some((*source_id, hit));
 
-        let remaining_needed = size as usize - final_hits.len();
+                    // Advance sorted_hits by 1 element, if have
+                    if sorted_hits.len() == 1 {
+                        *sorted_hits = &[];
+                    } else {
+                        *sorted_hits = &sorted_hits[1..];
+                    }
+                }
+            }
 
-        // Sort all hits by score descending, removing duplicates by document ID
-        all_hits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
+            let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
 
-        let extra_hits = all_hits
-            .into_iter()
-            .filter(|(source_id, _, _)| hits_per_source.contains_key(source_id)) // Only take from known sources
-            .filter_map(|(_, doc, _)| {
-                if !seen_docs.contains(&doc.document.id) {
-                    seen_docs.insert(doc.document.id.clone());
-                    Some(doc)
-                } else {
-                    None
-                }
-            })
-            .take(remaining_needed)
-            .collect::<Vec<_>>();
+            final_hits_grouped_by_source_id
+                .get_mut(source_id)
+                .expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist")
+                .push(hit.clone());
+        }
+    }
 
-        final_hits.extend(extra_hits);
+    let mut final_hits = Vec::new();
+    for (_source_id, hits) in final_hits_grouped_by_source_id {
+        final_hits.extend(hits);
     }
 
     // **Sort final hits by score descending**
```
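
The hunk above replaces the old rerank-or-backfill branching with a two-phase fill: each source first gets an equal quota of `size / n_sources` hits, and any remaining slots are backfilled from the per-source leftovers (`pruned`) by repeatedly extracting the globally highest-scoring head, the way an n-way merge does. Below is a minimal standalone sketch of that strategy, assuming plain `f64` scores in place of `QueryHits`; `take_evenly` is an illustrative name, not a function in this commit.

```rust
use std::collections::HashMap;

/// Hypothetical sketch of the quota-then-backfill strategy above. Each
/// source's scores must already be sorted descending, mirroring the
/// per-source sort in `query_coco_fusion_multi_query_sources`.
fn take_evenly(per_source: &HashMap<String, Vec<f64>>, size: usize) -> Vec<f64> {
    let n_sources = per_source.len();
    if n_sources == 0 {
        return Vec::new();
    }

    // Every source gets an equal quota; integer division may yield 0,
    // in which case the backfill below does all the work.
    let quota = size / n_sources;
    let mut taken = Vec::new();
    let mut pruned: Vec<&[f64]> = Vec::new(); // leftovers, still sorted descending

    for scores in per_source.values() {
        let cut = scores.len().min(quota);
        taken.extend_from_slice(&scores[..cut]);
        if scores.len() > cut {
            pruned.push(&scores[cut..]);
        }
    }

    // Backfill: repeatedly pop the globally highest head across the
    // leftover slices, the way an n-way merge extracts its next element.
    while taken.len() < size {
        let best = pruned
            .iter_mut()
            .filter(|s| !s.is_empty())
            .max_by(|a, b| a[0].partial_cmp(&b[0]).unwrap_or(std::cmp::Ordering::Equal));
        match best {
            Some(head) => {
                taken.push(head[0]);
                *head = &head[1..]; // advance this slice by one element
            }
            None => break, // leftovers exhausted
        }
    }

    // Final global sort, matching the end of the real function.
    taken.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    taken
}

fn main() {
    let per_source = HashMap::from([
        ("a".to_string(), vec![0.9, 0.8, 0.7, 0.6]),
        ("b".to_string(), vec![0.5]),
    ]);
    // Quota is 4 / 2 = 2: "a" contributes [0.9, 0.8], "b" contributes [0.5],
    // and the fourth slot is backfilled with 0.7 from "a"'s leftovers.
    assert_eq!(take_evenly(&per_source, 4), vec![0.9, 0.8, 0.7, 0.5]);
}
```

Like the real code, the sketch relies on each per-source list already being sorted descending, so inspecting only the head of every leftover slice is enough to find the global maximum.
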
```diff
@@ -379,6 +375,11 @@ async fn query_coco_fusion_multi_query_sources(
             .unwrap_or(std::cmp::Ordering::Equal)
     });
 
+    // Truncate `final_hits` in case it contains more than `size` hits
+    //
+    // Technically, we are safe to not do this. But since it is trivial, double-check it.
+    final_hits.truncate(size as usize);
+
     if final_hits.len() < 5 {
         //TODO: Add a recommendation system to suggest more sources
         log::info!(
@@ -395,30 +396,85 @@ async fn query_coco_fusion_multi_query_sources(
     })
 }
 
-fn boosted_levenshtein_rerank(query: &str, titles: Vec<(usize, &str)>) -> Vec<(usize, f64)> {
-    use strsim::levenshtein;
+use std::collections::HashSet;
+use strsim::levenshtein;
 
+fn boosted_levenshtein_rerank(
+    query: &str,
+    all_hits_grouped_by_source_id: &mut HashMap<String, Vec<QueryHits>>,
+) {
     let query_lower = query.to_lowercase();
 
-    titles
-        .into_iter()
-        .map(|(idx, title)| {
-            let mut score = 0.0;
+    for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() {
+        // Skip special sources like calculator
+        if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
+            continue;
+        }
+
+        for hit in hits.iter_mut() {
+            let document_title = hit.document.title.as_deref().unwrap_or("");
+            let document_title_lowercase = document_title.to_lowercase();
 
-            if title.contains(query) {
-                score += 0.4;
-            } else if title.to_lowercase().contains(&query_lower) {
-                score += 0.2;
-            }
+            let new_score = {
+                let mut score = 0.0;
 
-            let dist = levenshtein(&query_lower, &title.to_lowercase());
-            let max_len = query_lower.len().max(title.len());
-            if max_len > 0 {
-                score += (1.0 - (dist as f64 / max_len as f64)) as f32;
-            }
+                // --- Exact or substring boost ---
+                if document_title.contains(query) {
+                    score += 0.4;
+                } else if document_title_lowercase.contains(&query_lower) {
+                    score += 0.2;
+                }
+
+                // --- Levenshtein distance (character similarity) ---
+                let dist = levenshtein(&query_lower, &document_title_lowercase);
+                let max_len = query_lower.len().max(document_title.len());
+                let levenshtein_score = if max_len > 0 {
+                    (1.0 - (dist as f64 / max_len as f64)) as f32
+                } else {
+                    0.0
+                };
+
+                // --- Jaccard similarity (token overlap) ---
+                let jaccard_score = jaccard_similarity(&query_lower, &document_title_lowercase);
+
+                // --- Combine scores (weights adjustable) ---
+                // Levenshtein emphasizes surface similarity
+                // Jaccard emphasizes term overlap (semantic hint)
+                let hybrid_score = 0.7 * levenshtein_score + 0.3 * jaccard_score;
+
+                // --- Apply hybrid score ---
+                score += hybrid_score;
+
+                // --- Limit score range ---
+                score.min(1.0) as f64
+            };
+
+            hit.score = new_score;
+        }
+    }
+}
+
+/// Compute token-based Jaccard similarity
+fn jaccard_similarity(a: &str, b: &str) -> f32 {
+    let a_tokens: HashSet<_> = tokenize(a).into_iter().collect();
+    let b_tokens: HashSet<_> = tokenize(b).into_iter().collect();
+
+    if a_tokens.is_empty() || b_tokens.is_empty() {
+        return 0.0;
+    }
+
+    let intersection = a_tokens.intersection(&b_tokens).count() as f32;
+    let union = a_tokens.union(&b_tokens).count() as f32;
+
+    intersection / union
+}
 
-            (idx, score.min(1.0) as f64)
-        })
+/// Basic tokenizer (case-insensitive, alphanumeric words only)
+fn tokenize(text: &str) -> Vec<String> {
+    text.to_lowercase()
+        .split(|c: char| !c.is_alphanumeric())
+        .filter(|s| !s.is_empty())
+        .map(|s| s.to_string())
         .collect()
 }
 
```
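
The rewritten `boosted_levenshtein_rerank` scores every title with three signals: a substring boost (0.4 for a case-sensitive match, 0.2 for a case-insensitive one), a normalized Levenshtein similarity weighted at 0.7, and a token-level Jaccard similarity weighted at 0.3, with the sum capped at 1.0. The following self-contained sketch reproduces that formula on bare strings, leaving out the `QueryHits` plumbing; `hybrid_score` is an illustrative name, and the `strsim` version in the comment is an assumption.

```rust
use std::collections::HashSet;
use strsim::levenshtein; // Cargo.toml: strsim = "0.11" (version assumed)

/// Illustrative re-implementation of the per-hit score computed in
/// `boosted_levenshtein_rerank`, operating on bare strings.
fn hybrid_score(query: &str, title: &str) -> f64 {
    let query_lower = query.to_lowercase();
    let title_lower = title.to_lowercase();

    // Substring boost: a case-sensitive match outranks a case-insensitive one.
    let boost = if title.contains(query) {
        0.4
    } else if title_lower.contains(&query_lower) {
        0.2
    } else {
        0.0
    };

    // Normalized Levenshtein similarity in [0, 1].
    let dist = levenshtein(&query_lower, &title_lower) as f64;
    let max_len = query_lower.len().max(title.len()) as f64;
    let lev = if max_len > 0.0 { 1.0 - dist / max_len } else { 0.0 };

    // Token-level Jaccard similarity in [0, 1], mirroring
    // `jaccard_similarity` and `tokenize` above.
    let tokens = |s: &str| -> HashSet<String> {
        s.split(|c: char| !c.is_alphanumeric())
            .filter(|t| !t.is_empty())
            .map(|t| t.to_string())
            .collect()
    };
    let (a, b) = (tokens(&query_lower), tokens(&title_lower));
    let jaccard = if a.is_empty() || b.is_empty() {
        0.0
    } else {
        a.intersection(&b).count() as f64 / a.union(&b).count() as f64
    };

    // Weighted blend, capped at 1.0 as in the diff.
    (boost + 0.7 * lev + 0.3 * jaccard).min(1.0)
}

fn main() {
    // levenshtein("rust", "rust programming") = 12, max_len = 16, lev = 0.25;
    // Jaccard = |{rust}| / |{rust, programming}| = 0.5; boost = 0.2.
    // Total: 0.2 + 0.7 * 0.25 + 0.3 * 0.5 = 0.525
    println!("{:.3}", hybrid_score("rust", "Rust Programming"));
}
```

As in the diff, the normalization divides a character-level edit distance by a byte length (`str::len`); the two agree for ASCII titles.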

src/hooks/useSearch.ts

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
 import { useState, useCallback, useMemo, useRef } from "react";
-import { debounce } from "lodash-es";
+import { debounce, orderBy } from "lodash-es";
 
 import type {
   QueryHits,
@@ -65,7 +65,9 @@ export function useSearch() {
     response: MultiSourceQueryResponse,
     searchInput: string
   ) => {
-    const data = response?.hits || [];
+    const hits = response?.hits ?? [];
+
+    const data = orderBy(hits, "score", "desc");
 
     const searchData = data.reduce(
       (acc: SearchDataBySource, item: QueryHits) => {
```
