diff --git a/src/core/db.rs b/src/core/db.rs index 080e586..9143bb5 100644 --- a/src/core/db.rs +++ b/src/core/db.rs @@ -44,6 +44,36 @@ pub struct DatabaseStats { pub last_indexed: Option>, } +use std::collections::BinaryHeap; +use std::cmp::Ordering; + +// Wrapper struct for BinaryHeap to keep smallest similarity at the top (Min-Heap behavior for Max-K) +#[derive(Debug)] +struct HeapItem { + result: SearchResult, +} + +impl PartialEq for HeapItem { + fn eq(&self, other: &Self) -> bool { + self.result.similarity == other.result.similarity + } +} + +impl Eq for HeapItem {} + +impl PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapItem { + fn cmp(&self, other: &Self) -> Ordering { + // Reverse order so the smallest similarity is at the top (popped first) + other.result.similarity.partial_cmp(&self.result.similarity).unwrap_or(Ordering::Equal) + } +} + impl Database { pub fn new(path: &Path) -> Result { let conn = Connection::open(path)?; @@ -168,33 +198,44 @@ impl Database { WHERE f.path LIKE ?", )?; - let mut results: Vec = stmt - .query_map([&like_pattern], |row| { - let embedding_blob: Vec = row.get(6)?; - let embedding = bytes_to_embedding(&embedding_blob); - let similarity = cosine_similarity(query_embedding, &embedding); - - Ok(SearchResult { - chunk_id: row.get(0)?, - file_id: row.get(1)?, - path: PathBuf::from(row.get::<_, String>(2)?), - content: row.get(3)?, - start_line: row.get(4)?, - end_line: row.get(5)?, - similarity, - }) - })? - .filter_map(Result::ok) - .collect(); - - // Sort by similarity (highest first) - results.sort_by(|a, b| { - b.similarity - .partial_cmp(&a.similarity) - .unwrap_or(std::cmp::Ordering::Equal) - }); + // Use a min-heap to keep track of the top K results. + // We only store `limit * 3` items in memory at any given time. + let capacity = limit * 3; + let mut heap: BinaryHeap = BinaryHeap::with_capacity(capacity + 1); + + let rows = stmt.query_map([&like_pattern], |row| { + let embedding_blob: Vec = row.get(6)?; + let embedding = bytes_to_embedding(&embedding_blob); + let similarity = cosine_similarity(query_embedding, &embedding); + + Ok(SearchResult { + chunk_id: row.get(0)?, + file_id: row.get(1)?, + path: PathBuf::from(row.get::<_, String>(2)?), + content: row.get(3)?, + start_line: row.get(4)?, + end_line: row.get(5)?, + similarity, + }) + })?; + + for row in rows { + let result = row?; + + if heap.len() < capacity { + heap.push(HeapItem { result }); + } else if let Some(min_item) = heap.peek() { + // If the current result is better than the worst item in the heap + if result.similarity > min_item.result.similarity { + heap.pop(); + heap.push(HeapItem { result }); + } + } + } - results.truncate(limit * 3); // Get more for reranking + // Convert heap to sorted vector (highest similarity first) + let results: Vec = heap.into_sorted_vec().into_iter().map(|item| item.result).collect(); + Ok(results) }