Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ fn run_index(

let db = Database::new(&config.db_path()?)?;
let engine = EmbeddingEngine::new(config)?;
let indexer = Indexer::new(db, engine, max_size);
let mut indexer = Indexer::new(db, engine, max_size);
indexer.index_directory(&path, force)?;
}
}
Expand Down Expand Up @@ -670,7 +670,7 @@ fn run_search_local(

let db = Database::new(&config.db_path()?)?;
let engine = EmbeddingEngine::new(config)?;
let search = SearchEngine::new(db, engine, config, config.use_reranker)?;
let mut search = SearchEngine::new(db, engine, config, config.use_reranker)?;

if interactive {
let mut tui = SearchTui::new(search)?;
Expand Down
63 changes: 36 additions & 27 deletions src/core/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,23 @@ fn suppress_llama_logs() {
});
}

pub struct EmbeddingEngine {
/// Bundles the llama backend and the loaded model so they can be owned and
/// moved together as one unit.
///
/// `EmbeddingEngine` stores this behind a `Box` and builds its context from
/// references to these two fields, so both must stay alive for as long as
/// that context exists.
// NOTE(review): fields drop in declaration order, so `backend` is dropped
// before `model` — confirm the llama bindings are fine with that order.
struct EngineResources {
backend: LlamaBackend,
model: LlamaModel,
}

/// Embedding engine that keeps a single llama context alive next to the
/// backend and model the context was created from.
///
/// The struct is self-referential: `ctx` holds references into `resources`.
/// Do not reorder the fields — see the comment on `ctx`.
pub struct EmbeddingEngine {
// IMPORTANT: ctx must be defined before resources to ensure it is dropped first.
// ctx borrows from model/backend which are in resources.
// Rust drops struct fields in declaration order, which is what makes this work.
// The `'static` lifetime is not a real one: it is produced by a transmute in
// `new`, so this field must stay private and must never be handed out.
// `new` always stores `Some`; `embed_batch` reports "Context not initialized"
// if it ever finds `None`.
ctx: Option<llama_cpp_2::context::LlamaContext<'static>>,
// Boxed so the backend and model keep a stable heap address even when the
// engine itself is moved.
resources: Box<EngineResources>,
// The smaller of the context's reported `n_ctx()` and the `context_size`
// computed in `new`; used as the batch capacity in `process_tokens`.
n_ctx: usize,
}

// Ensure we can send it across threads (Mutex wrapper requires this)
// LlamaContext wraps a pointer and is generally thread-safe to move but not to share.
// SAFETY: NOTE(review) — this relies on the assumption that a llama context,
// model and backend may be used from a thread other than the one that created
// them, as long as only one thread touches them at a time. Nothing in this
// file proves that; confirm against the llama_cpp_2 / llama.cpp documentation.
// `Sync` is deliberately NOT implemented: shared access must go through a lock.
unsafe impl Send for EmbeddingEngine {}

impl EmbeddingEngine {
pub fn new(config: &Config) -> Result<Self> {
let model_path = config.embedding_model_path()?;
Expand All @@ -41,56 +52,53 @@ impl EmbeddingEngine {
let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
.context("Failed to load embedding model")?;

let resources = Box::new(EngineResources { backend, model });

let ctx_params = LlamaContextParams::default()
.with_n_threads_batch(n_threads)
.with_n_threads(n_threads)
.with_embeddings(true);

let ctx = model
.new_context(&backend, ctx_params)
.context("Failed to create context")?;
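// SAFETY: We are creating a self-referential struct.
// 1. `resources` is boxed, so the addresses of `model` and `backend` stay stable
//    even when the surrounding `EmbeddingEngine` is moved.
// 2. We extend the borrows of `model` and `backend` to 'static so that `ctx` can be
//    stored next to them; this forged lifetime must never escape the struct.
// 3. We store `ctx` in the same struct as `resources`, so `ctx` cannot outlive them.
// 4. `ctx` is dropped BEFORE `resources` because it is declared earlier in the struct.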
let ctx = unsafe {
let model_ref: &'static LlamaModel = std::mem::transmute(&resources.model);
let backend_ref: &'static LlamaBackend = std::mem::transmute(&resources.backend);
model_ref
.new_context(backend_ref, ctx_params)
.context("Failed to create context")?
};

let n_ctx = std::cmp::min(ctx.n_ctx() as usize, context_size);
drop(ctx);

Ok(Self {
backend,
model,
ctx: Some(ctx),
resources,
n_ctx,
})
}

pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
/// Embeds a single piece of text.
///
/// Wraps `text` in a one-element slice, delegates to `embed_batch`, and
/// returns the first vector of the result.
///
/// # Errors
/// Propagates any error from `embed_batch`, and fails with
/// "No embedding generated" if the batch comes back empty.
pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
    self.embed_batch(&[text])?
        .into_iter()
        .next()
        .context("No embedding generated")
}

pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
pub fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}

let n_threads = std::thread::available_parallelism()
.map(|p| p.get() as i32)
.unwrap_or(4);

let ctx_params = LlamaContextParams::default()
.with_n_threads_batch(n_threads)
.with_n_threads(n_threads)
.with_embeddings(true);

let mut ctx = self
.model
.new_context(&self.backend, ctx_params)
.context("Failed to create context")?;

let mut results = Vec::with_capacity(texts.len());

for text in texts {
let tokens = self
.resources
.model
.str_to_token(text, AddBos::Always)
.context("Failed to tokenize")?;
Expand All @@ -101,19 +109,20 @@ impl EmbeddingEngine {
tokens
};

let embedding = self.process_tokens(&mut ctx, &tokens)?;
let ctx = self.ctx.as_mut().context("Context not initialized")?;
let embedding = Self::process_tokens(self.n_ctx, ctx, &tokens)?;
results.push(embedding);
}

Ok(results)
}

fn process_tokens(
&self,
n_ctx: usize,
ctx: &mut llama_cpp_2::context::LlamaContext,
tokens: &[llama_cpp_2::token::LlamaToken],
) -> Result<Vec<f32>> {
let mut batch = LlamaBatch::new(self.n_ctx, 1);
let mut batch = LlamaBatch::new(n_ctx, 1);
batch.add_sequence(tokens, 0, false)?;

ctx.clear_kv_cache();
Expand All @@ -127,7 +136,7 @@ impl EmbeddingEngine {
}

pub fn embedding_dim(&self) -> usize {
self.model.n_embd() as usize
self.resources.model.n_embd() as usize
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Indexer {
}
}

pub fn index_directory(&self, path: &Path, force: bool) -> Result<()> {
pub fn index_directory(&mut self, path: &Path, force: bool) -> Result<()> {
let abs_path = fs::canonicalize(path).context("Failed to resolve path")?;

println!(
Expand Down
6 changes: 3 additions & 3 deletions src/core/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl SearchEngine {
}

pub fn search(
&self,
&mut self,
query: &str,
path: &Path,
max_results: usize,
Expand Down Expand Up @@ -90,12 +90,12 @@ impl SearchEngine {
Ok(results)
}

pub fn search_interactive(&self, query: &str, max_results: usize) -> Result<Vec<SearchResult>> {
/// Runs `search` for `query` rooted at the process's current working directory.
///
/// # Errors
/// Fails if the current directory cannot be determined, or with whatever
/// error `search` returns.
pub fn search_interactive(&mut self, query: &str, max_results: usize) -> Result<Vec<SearchResult>> {
    let working_dir = std::env::current_dir()?;
    self.search(query, working_dir.as_path(), max_results)
}

pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
/// Embeds `text` by forwarding to the embedding engine owned by this search engine.
///
/// # Errors
/// Returns whatever error the embedding engine's `embed` reports.
pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
    let engine = &mut self.embedding_engine;
    engine.embed(text)
}
}
6 changes: 3 additions & 3 deletions src/server/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async fn search(

// Generate query embedding
let query_embedding = {
let engine = match state.embedding_engine.lock() {
let mut engine = match state.embedding_engine.lock() {
Ok(e) => e,
Err(e) => {
return (
Expand Down Expand Up @@ -303,7 +303,7 @@ async fn embed(
State(state): State<SharedState>,
Json(req): Json<EmbedRequest>,
) -> impl IntoResponse {
let engine = match state.embedding_engine.lock() {
let mut engine = match state.embedding_engine.lock() {
Ok(e) => e,
Err(e) => {
return (
Expand Down Expand Up @@ -339,7 +339,7 @@ async fn embed_batch(
State(state): State<SharedState>,
Json(req): Json<EmbedBatchRequest>,
) -> impl IntoResponse {
let engine = match state.embedding_engine.lock() {
let mut engine = match state.embedding_engine.lock() {
Ok(e) => e,
Err(e) => {
return (
Expand Down
2 changes: 1 addition & 1 deletion src/watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ impl FileWatcher {
}
let db = Database::new(&self.config.db_path()?)?;
let engine = crate::core::EmbeddingEngine::new(&self.config)?;
let indexer = Indexer::new(db, engine, self.config.max_file_size);
let mut indexer = Indexer::new(db, engine, self.config.max_file_size);
indexer.index_directory(&self.root_path, false)?;
}
}
Expand Down
Loading