diff --git a/src/cli/commands.rs b/src/cli/commands.rs index f128499..d740e01 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -531,7 +531,7 @@ fn run_index( let db = Database::new(&config.db_path()?)?; let engine = EmbeddingEngine::new(config)?; - let indexer = Indexer::new(db, engine, max_size); + let mut indexer = Indexer::new(db, engine, max_size); indexer.index_directory(&path, force)?; } } @@ -670,7 +670,7 @@ fn run_search_local( let db = Database::new(&config.db_path()?)?; let engine = EmbeddingEngine::new(config)?; - let search = SearchEngine::new(db, engine, config, config.use_reranker)?; + let mut search = SearchEngine::new(db, engine, config, config.use_reranker)?; if interactive { let mut tui = SearchTui::new(search)?; diff --git a/src/core/embeddings.rs b/src/core/embeddings.rs index be77c5b..ce323c0 100644 --- a/src/core/embeddings.rs +++ b/src/core/embeddings.rs @@ -19,12 +19,23 @@ fn suppress_llama_logs() { }); } -pub struct EmbeddingEngine { +struct EngineResources { backend: LlamaBackend, model: LlamaModel, +} + +pub struct EmbeddingEngine { + // IMPORTANT: ctx must be defined before resources to ensure it is dropped first. + // ctx borrows from model/backend which are in resources. + ctx: Option<LlamaContext<'static>>, + resources: Box<EngineResources>, n_ctx: usize, } +// Ensure we can send it across threads (Mutex wrapper requires this) // LlamaContext wraps a pointer and is generally thread-safe to move but not to share. 
+unsafe impl Send for EmbeddingEngine {} + impl EmbeddingEngine { pub fn new(config: &Config) -> Result<Self> { let model_path = config.embedding_model_path()?; @@ -41,26 +52,36 @@ impl EmbeddingEngine { let model = LlamaModel::load_from_file(&backend, model_path, &model_params) .context("Failed to load embedding model")?; + let resources = Box::new(EngineResources { backend, model }); + let ctx_params = LlamaContextParams::default() .with_n_threads_batch(n_threads) .with_n_threads(n_threads) .with_embeddings(true); - let ctx = model - .new_context(&backend, ctx_params) - .context("Failed to create context")?; + // SAFETY: We are creating a self-referential struct. + // 1. `resources` is boxed, so its address is stable. + // 2. We extend the lifetime of `model` to 'static temporarily to create `ctx`. + // 3. We store `ctx` in the same struct. + // 4. `ctx` is dropped BEFORE `resources` because it is declared earlier in the struct. + let ctx = unsafe { + let model_ref: &'static LlamaModel = std::mem::transmute(&resources.model); + let backend_ref: &'static LlamaBackend = std::mem::transmute(&resources.backend); + model_ref + .new_context(backend_ref, ctx_params) + .context("Failed to create context")? 
+ }; let n_ctx = std::cmp::min(ctx.n_ctx() as usize, context_size); - drop(ctx); Ok(Self { - backend, - model, + ctx: Some(ctx), + resources, n_ctx, }) } - pub fn embed(&self, text: &str) -> Result<Vec<f32>> { + pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> { let embeddings = self.embed_batch(&[text])?; embeddings .into_iter() .next() .context("No embedding generated") } - pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> { + pub fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> { if texts.is_empty() { return Ok(Vec::new()); } - let n_threads = std::thread::available_parallelism() .map(|p| p.get() as i32) .unwrap_or(4); - - let ctx_params = LlamaContextParams::default() .with_n_threads_batch(n_threads) .with_n_threads(n_threads) .with_embeddings(true); - - let mut ctx = self - .model - .new_context(&self.backend, ctx_params) - .context("Failed to create context")?; - let mut results = Vec::with_capacity(texts.len()); for text in texts { let tokens = self + .resources .model .str_to_token(text, AddBos::Always) .context("Failed to tokenize")?; @@ -101,7 +109,8 @@ impl EmbeddingEngine { tokens }; - let embedding = self.process_tokens(&mut ctx, &tokens)?; + let ctx = self.ctx.as_mut().context("Context not initialized")?; + let embedding = Self::process_tokens(self.n_ctx, ctx, &tokens)?; results.push(embedding); } @@ -109,11 +118,11 @@ impl EmbeddingEngine { } fn process_tokens( - &self, + n_ctx: usize, ctx: &mut llama_cpp_2::context::LlamaContext, tokens: &[llama_cpp_2::token::LlamaToken], ) -> Result<Vec<f32>> { - let mut batch = LlamaBatch::new(self.n_ctx, 1); + let mut batch = LlamaBatch::new(n_ctx, 1); batch.add_sequence(tokens, 0, false)?; ctx.clear_kv_cache(); @@ -127,7 +136,7 @@ impl EmbeddingEngine { } pub fn embedding_dim(&self) -> usize { - self.model.n_embd() as usize + self.resources.model.n_embd() as usize } } diff --git a/src/core/indexer.rs b/src/core/indexer.rs index 469c120..30eb1fa 100644 --- a/src/core/indexer.rs 
+++ b/src/core/indexer.rs @@ -42,7 +42,7 @@ impl Indexer { } } - pub fn index_directory(&self, path: &Path, force: bool) -> Result<()> { + pub fn index_directory(&mut self, path: &Path, force: bool) -> Result<()> { let abs_path = fs::canonicalize(path).context("Failed to resolve path")?; println!( diff --git a/src/core/search.rs b/src/core/search.rs index 1690ba4..87a1bfa 100644 --- a/src/core/search.rs +++ b/src/core/search.rs @@ -35,7 +35,7 @@ impl SearchEngine { } pub fn search( - &self, + &mut self, query: &str, path: &Path, max_results: usize, @@ -90,12 +90,12 @@ impl SearchEngine { Ok(results) } - pub fn search_interactive(&self, query: &str, max_results: usize) -> Result<Vec<SearchResult>> { + pub fn search_interactive(&mut self, query: &str, max_results: usize) -> Result<Vec<SearchResult>> { let cwd = std::env::current_dir()?; self.search(query, &cwd, max_results) } - pub fn embed(&self, text: &str) -> Result<Vec<f32>> { + pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> { self.embedding_engine.embed(text) } } diff --git a/src/server/api.rs b/src/server/api.rs index bd5cc8a..cc181cc 100644 --- a/src/server/api.rs +++ b/src/server/api.rs @@ -200,7 +200,7 @@ async fn search( // Generate query embedding let query_embedding = { - let engine = match state.embedding_engine.lock() { + let mut engine = match state.embedding_engine.lock() { Ok(e) => e, Err(e) => { return ( @@ -303,7 +303,7 @@ async fn embed( State(state): State<AppState>, Json(req): Json<EmbedRequest>, ) -> impl IntoResponse { - let engine = match state.embedding_engine.lock() { + let mut engine = match state.embedding_engine.lock() { Ok(e) => e, Err(e) => { return ( @@ -339,7 +339,7 @@ async fn embed_batch( State(state): State<AppState>, Json(req): Json<EmbedBatchRequest>, ) -> impl IntoResponse { - let engine = match state.embedding_engine.lock() { + let mut engine = match state.embedding_engine.lock() { Ok(e) => e, Err(e) => { return ( diff --git a/src/watcher.rs b/src/watcher.rs index 816219f..3d2facb 100644 --- a/src/watcher.rs +++ b/src/watcher.rs @@ -279,7 +279,7 @@ impl FileWatcher { } let 
db = Database::new(&self.config.db_path()?)?; let engine = crate::core::EmbeddingEngine::new(&self.config)?; - let indexer = Indexer::new(db, engine, self.config.max_file_size); + let mut indexer = Indexer::new(db, engine, self.config.max_file_size); indexer.index_directory(&self.root_path, false)?; } }