Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 67 additions & 26 deletions src/core/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,36 @@ pub struct DatabaseStats {
pub last_indexed: Option<DateTime<Utc>>,
}

use std::collections::BinaryHeap;
use std::cmp::Ordering;

// Wrapper struct for BinaryHeap to keep smallest similarity at the top (Min-Heap behavior for Max-K)
#[derive(Debug)]
struct HeapItem {
result: SearchResult,
}

impl PartialEq for HeapItem {
fn eq(&self, other: &Self) -> bool {
self.result.similarity == other.result.similarity
}
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for HeapItem {
fn cmp(&self, other: &Self) -> Ordering {
// Reverse order so the smallest similarity is at the top (popped first)
other.result.similarity.partial_cmp(&self.result.similarity).unwrap_or(Ordering::Equal)
}
}

impl Database {
pub fn new(path: &Path) -> Result<Self> {
let conn = Connection::open(path)?;
Expand Down Expand Up @@ -168,33 +198,44 @@ impl Database {
WHERE f.path LIKE ?",
)?;

let mut results: Vec<SearchResult> = stmt
.query_map([&like_pattern], |row| {
let embedding_blob: Vec<u8> = row.get(6)?;
let embedding = bytes_to_embedding(&embedding_blob);
let similarity = cosine_similarity(query_embedding, &embedding);

Ok(SearchResult {
chunk_id: row.get(0)?,
file_id: row.get(1)?,
path: PathBuf::from(row.get::<_, String>(2)?),
content: row.get(3)?,
start_line: row.get(4)?,
end_line: row.get(5)?,
similarity,
})
})?
.filter_map(Result::ok)
.collect();

// Sort by similarity (highest first)
results.sort_by(|a, b| {
b.similarity
.partial_cmp(&a.similarity)
.unwrap_or(std::cmp::Ordering::Equal)
});
// Use a min-heap to keep track of the top K results.
// We only store `limit * 3` items in memory at any given time.
let capacity = limit * 3;
let mut heap: BinaryHeap<HeapItem> = BinaryHeap::with_capacity(capacity + 1);

let rows = stmt.query_map([&like_pattern], |row| {
let embedding_blob: Vec<u8> = row.get(6)?;
let embedding = bytes_to_embedding(&embedding_blob);
let similarity = cosine_similarity(query_embedding, &embedding);

Ok(SearchResult {
chunk_id: row.get(0)?,
file_id: row.get(1)?,
path: PathBuf::from(row.get::<_, String>(2)?),
content: row.get(3)?,
start_line: row.get(4)?,
end_line: row.get(5)?,
similarity,
})
})?;

for row in rows {
let result = row?;

if heap.len() < capacity {
heap.push(HeapItem { result });
} else if let Some(min_item) = heap.peek() {
// If the current result is better than the worst item in the heap
if result.similarity > min_item.result.similarity {
heap.pop();
heap.push(HeapItem { result });
}
}
}

results.truncate(limit * 3); // Get more for reranking
// Convert heap to sorted vector (highest similarity first)
let results: Vec<SearchResult> = heap.into_sorted_vec().into_iter().map(|item| item.result).collect();

Ok(results)
}

Expand Down
Loading