diff --git a/README.md b/README.md
index d6fca38..4c7e347 100644
--- a/README.md
+++ b/README.md
@@ -50,13 +50,23 @@ curl -sSL https://raw.githubusercontent.com/yoanbernabeu/grepai/main/install.sh
 irm https://raw.githubusercontent.com/yoanbernabeu/grepai/main/install.ps1 | iex
 ```

-Requires an embedding provider — [Ollama](https://ollama.ai) (default), [LM Studio](https://lmstudio.ai), or OpenAI.
+Requires an embedding provider — [Ollama](https://ollama.ai) (default), managed local `llama.cpp`, [LM Studio](https://lmstudio.ai), or OpenAI.

 **Ollama (recommended):**
 ```bash
 ollama pull nomic-embed-text
 ```

+**Managed local `llama.cpp`:**
+```bash
+grepai init --provider llamacpp
+grepai model install
+grepai model use bge-small-en-v1.5-q8_0
+```
+
+If you already have managed local models installed, plain `grepai init` will ask which installed `llamacpp` model to use when you choose the `llamacpp` provider.
+Managed `llama.cpp` runtime support is currently limited to macOS (`arm64`, `amd64`), Linux (`amd64`), and Windows (`amd64`).
+
 ## Quick Start

 ```bash
@@ -68,7 +78,7 @@ grepai trace callers "Login" # Find who calls a function

 ## Shell Completion

-grepai supports autocompletion for commands, flags, and dynamic values (workspace names, project names, providers, backends).
+grepai supports autocompletion for commands, flags, and dynamic values (workspace names, project names, providers, backends, and managed model ids for `llamacpp`).

 **Zsh (add to `~/.zshrc`):**
 ```bash
diff --git a/cli/completion.go b/cli/completion.go
index 65f7d6e..e57ab48 100644
--- a/cli/completion.go
+++ b/cli/completion.go
@@ -3,6 +3,7 @@ package cli
 import (
 	"github.com/spf13/cobra"
 	"github.com/yoanbernabeu/grepai/config"
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
 )

 var completionCmd = &cobra.Command{
@@ -92,13 +93,17 @@ func init() {
 func registerCompletions() {
 	// Static flag completions for initCmd
 	_ = initCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
-		return []string{
+		completions := []string{
 			"ollama\tLocal embedding with Ollama",
 			"lmstudio\tLocal embedding with LM Studio",
 			"openai\tCloud embedding with OpenAI",
 			"synthetic\tCloud embedding with Synthetic (free)",
 			"openrouter\tCloud multi-provider gateway",
-		}, cobra.ShellCompDirectiveNoFileComp
+		}
+		if managedLlamaCPPSupported() {
+			completions = append(completions, "llamacpp\tManaged local embedding with llama.cpp")
+		}
+		return completions, cobra.ShellCompDirectiveNoFileComp
 	})
 	_ = initCmd.RegisterFlagCompletionFunc("backend", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		return []string{
@@ -107,6 +112,26 @@ func registerCompletions() {
 			"qdrant\tQdrant vector database",
 		}, cobra.ShellCompDirectiveNoFileComp
 	})
+	_ = initCmd.RegisterFlagCompletionFunc("model", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+		provider, _ := cmd.Flags().GetString("provider")
+		switch provider {
+		case "llamacpp":
+			return completeAvailableManagedModels(), cobra.ShellCompDirectiveNoFileComp
+		case "openai":
+			return []string{
+				"text-embedding-3-small\tOpenAI small embedding model",
+				"text-embedding-3-large\tOpenAI large embedding model",
+			}, cobra.ShellCompDirectiveNoFileComp
+		case "openrouter":
+			return []string{
+				"openai/text-embedding-3-small\tOpenRouter small embedding model",
+				"openai/text-embedding-3-large\tOpenRouter large embedding model",
"qwen/qwen3-embedding-8b\tOpenRouter Qwen code-focused embedding model", + }, cobra.ShellCompDirectiveNoFileComp + default: + return nil, cobra.ShellCompDirectiveNoFileComp + } + }) // Static flag completions for workspaceCreateCmd _ = workspaceCreateCmd.RegisterFlagCompletionFunc("backend", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { @@ -116,15 +141,38 @@ func registerCompletions() { }, cobra.ShellCompDirectiveNoFileComp }) _ = workspaceCreateCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return []string{ + completions := []string{ "ollama\tLocal embedding with Ollama", "lmstudio\tLocal embedding with LM Studio", "openai\tCloud embedding with OpenAI", "synthetic\tCloud embedding with Synthetic (free)", "openrouter\tCloud multi-provider gateway", - }, cobra.ShellCompDirectiveNoFileComp + } + if managedLlamaCPPSupported() { + completions = append(completions, "llamacpp\tManaged local embedding with llama.cpp") + } + return completions, cobra.ShellCompDirectiveNoFileComp }) + modelUseCmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) == 0 { + return completeInstalledManagedModels(), cobra.ShellCompDirectiveNoFileComp + } + return nil, cobra.ShellCompDirectiveNoFileComp + } + modelRemoveCmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) == 0 { + return completeInstalledManagedModels(), cobra.ShellCompDirectiveNoFileComp + } + return nil, cobra.ShellCompDirectiveNoFileComp + } + modelInstallCmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) == 0 { + return completeAvailableManagedModels(), cobra.ShellCompDirectiveNoFileComp + } + return nil, cobra.ShellCompDirectiveNoFileComp + } + // Static flag completions for trace commands (mode) for _, cmd := range []*cobra.Command{traceCallersCmd, traceCalleesCmd, traceGraphCmd} { cmd := cmd @@ -218,3 +266,24 @@ func completeProjectNames(workspaceName string) []string { } return names } + +func completeAvailableManagedModels() []string { + models := managedassets.ListAvailableModels() + completions := make([]string, 0, len(models)) + for _, model := range models { + completions = append(completions, model.ID+"\t"+model.Display) + } + return completions +} + +func completeInstalledManagedModels() []string { + models, err := managedassets.LoadInstalledModels() + if err != nil { + return nil + } + completions := make([]string, 0, len(models)) + for _, model := range models { + completions = append(completions, model.ID+"\tinstalled managed model") + } + return completions +} diff --git a/cli/completion_test.go b/cli/completion_test.go index c9a21f1..1d69380 100644 --- a/cli/completion_test.go +++ b/cli/completion_test.go @@ -8,9 +8,17 @@ import ( "testing" "github.com/yoanbernabeu/grepai/config" + "github.com/yoanbernabeu/grepai/internal/managedassets" ) func TestCompletionZsh_should_output_compdef(t *testing.T) { + prevProvider := initProvider + prevModel := initModel + defer func() { + initProvider = prevProvider + initModel = prevModel + }() + var buf bytes.Buffer rootCmd.SetOut(&buf) rootCmd.SetArgs([]string{"completion", "zsh"}) @@ -30,6 +38,13 @@ func TestCompletionZsh_should_output_compdef(t *testing.T) { } func 
 func TestCompletionBash_should_output_valid_script(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
 	var buf bytes.Buffer
 	rootCmd.SetOut(&buf)
 	rootCmd.SetArgs([]string{"completion", "bash"})
@@ -45,6 +60,13 @@ func TestCompletionBash_should_output_valid_script(t *testing.T) {
 }

 func TestCompletionFish_should_output_valid_script(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
 	var buf bytes.Buffer
 	rootCmd.SetOut(&buf)
 	rootCmd.SetArgs([]string{"completion", "fish"})
@@ -60,6 +82,13 @@ func TestCompletionFish_should_output_valid_script(t *testing.T) {
 }

 func TestCompletionPowershell_should_output_valid_script(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
 	var buf bytes.Buffer
 	rootCmd.SetOut(&buf)
 	rootCmd.SetArgs([]string{"completion", "powershell"})
@@ -133,3 +162,93 @@ func TestCompleteProjectNames_should_return_project_names(t *testing.T) {
 		t.Fatalf("expected frontend and backend, got: %v", names)
 	}
 }
+
+func TestCompletionScriptIncludesLlamaCPPProvider(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
+	var buf bytes.Buffer
+	rootCmd.SetOut(&buf)
+	rootCmd.SetArgs([]string{"__complete", "init", "--provider", ""})
+	defer rootCmd.SetOut(nil)
+
+	if err := rootCmd.Execute(); err != nil {
+		t.Fatalf("provider completion failed: %v", err)
+	}
+
+	if !strings.Contains(buf.String(), "llamacpp") {
+		t.Fatalf("expected llamacpp in completion output, got: %s", buf.String())
+	}
+}
+
+func TestCompletionInitModelIncludesManagedModelsForLlamaCPP(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
+	var buf bytes.Buffer
+	rootCmd.SetOut(&buf)
+	rootCmd.SetArgs([]string{"__complete", "init", "--provider", "llamacpp", "--model", ""})
+	defer rootCmd.SetOut(nil)
+
+	if err := rootCmd.Execute(); err != nil {
+		t.Fatalf("model completion failed: %v", err)
+	}
+
+	output := buf.String()
+	if !strings.Contains(output, "bge-small-en-v1.5-q8_0") {
+		t.Fatalf("expected default managed model in completion output, got: %s", output)
+	}
+	if !strings.Contains(output, "nomic-embed-text-v1.5-q8_0") {
+		t.Fatalf("expected Nomic managed model in completion output, got: %s", output)
+	}
+}
+
+func TestCompletionModelUseIncludesInstalledModels(t *testing.T) {
+	prevProvider := initProvider
+	prevModel := initModel
+	defer func() {
+		initProvider = prevProvider
+		initModel = prevModel
+	}()
+
+	tmpDir := t.TempDir()
+	oldHome := os.Getenv("HOME")
+	_ = os.Setenv("HOME", tmpDir)
+	defer os.Setenv("HOME", oldHome)
+
+	modelDef, err := managedassets.LookupModel("nomic-embed-text-v1.5-q8_0")
+	if err != nil {
+		t.Fatalf("LookupModel failed: %v", err)
+	}
+	if err := managedassets.SaveInstalledModels([]managedassets.InstalledModel{{
+		ID: modelDef.ID,
+		FileName: modelDef.FileName,
+		Path: filepath.Join(tmpDir, modelDef.FileName),
+		SourceURL: modelDef.URL,
+		SizeBytes: modelDef.SizeBytes,
+		Dimensions: modelDef.Dimensions,
+	}}); err != nil {
+		t.Fatalf("SaveInstalledModels failed: %v", err)
+	}
+
+	var buf bytes.Buffer
+	rootCmd.SetOut(&buf)
+	rootCmd.SetArgs([]string{"__complete", "model", "use", ""})
+	defer rootCmd.SetOut(nil)
+
+	if err := rootCmd.Execute(); err != nil {
+		t.Fatalf("model use completion failed: %v", err)
+	}
+
+	if !strings.Contains(buf.String(), modelDef.ID) {
+		t.Fatalf("expected installed model in completion output, got: %s", buf.String())
+	}
+}
diff --git a/cli/init.go b/cli/init.go
index 6dd0bff..54a32dd 100644
--- a/cli/init.go
+++ b/cli/init.go
@@ -3,6 +3,7 @@ package cli
 import (
 	"bufio"
 	"fmt"
+	"io"
 	"os"
 	"strings"

@@ -10,6 +11,7 @@ import (
 	"github.com/yoanbernabeu/grepai/config"
 	"github.com/yoanbernabeu/grepai/git"
 	"github.com/yoanbernabeu/grepai/indexer"
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
 )

 var (
@@ -39,8 +41,8 @@ This command will:
 }

 func init() {
-	initCmd.Flags().StringVarP(&initProvider, "provider", "p", "", "Embedding provider (ollama, lmstudio, openai, synthetic, or openrouter)")
-	initCmd.Flags().StringVarP(&initModel, "model", "m", "", "Embedding model (for openai/openrouter: text-embedding-3-small, text-embedding-3-large; openrouter also supports qwen3-embedding-8b)")
+	initCmd.Flags().StringVarP(&initProvider, "provider", "p", "", "Embedding provider (ollama, llamacpp, lmstudio, openai, synthetic, or openrouter)")
+	initCmd.Flags().StringVarP(&initModel, "model", "m", "", "Embedding model (for llamacpp: managed model id from 'grepai model list-available'; for openai/openrouter: text-embedding-3-small, text-embedding-3-large; openrouter also supports qwen3-embedding-8b)")
 	initCmd.Flags().StringVarP(&initBackend, "backend", "b", "", "Storage backend (gob, postgres, or qdrant)")
 	initCmd.Flags().BoolVar(&initNonInteractive, "yes", false, "Use defaults without prompting")
 	initCmd.Flags().BoolVar(&initInherit, "inherit", false, "Inherit configuration from main worktree (for git worktrees)")
@@ -119,17 +121,29 @@ func runInit(cmd *cobra.Command, args []string) error {
 	if initProvider == "" {
 		fmt.Println("\nSelect embedding provider:")
 		fmt.Println(" 1) ollama (local, privacy-first, requires Ollama running)")
-		fmt.Println(" 2) lmstudio (local, OpenAI-compatible, requires LM Studio running)")
-		fmt.Println(" 3) openai (cloud, requires API key)")
-		fmt.Println(" 4) synthetic (cloud, free embedding API)")
-		fmt.Println(" 5) openrouter (cloud, multi-provider gateway)")
+		if managedLlamaCPPSupported() {
+			fmt.Println(" 2) llamacpp (local, managed runtime + managed model)")
+		}
+		fmt.Println(" 3) lmstudio (local, OpenAI-compatible, requires LM Studio running)")
+		fmt.Println(" 4) openai (cloud, requires API key)")
+		fmt.Println(" 5) synthetic (cloud, free embedding API)")
+		fmt.Println(" 6) openrouter (cloud, multi-provider gateway)")
 		fmt.Print("Choice [1]: ")
 		input, _ := reader.ReadString('\n')
 		input = strings.TrimSpace(input)
 		switch input {
-		case "2", "lmstudio":
+		case "2", "llamacpp":
+			if !managedLlamaCPPSupported() {
+				return managedLlamaCPPUnsupportedError()
+			}
+			cfg.Embedder.Provider = "llamacpp"
+			cfg.Embedder.Model = resolveInteractiveLlamaCPPModel(reader, cmd.OutOrStdout(), initModel)
+			cfg.Embedder.Endpoint = config.DefaultLlamaCPPEndpoint
+			dim := resolveLocalModelDimensions(cfg.Embedder.Model)
+			cfg.Embedder.Dimensions = &dim
+		case "3", "lmstudio":
 			cfg.Embedder.Provider = "lmstudio"
 			fmt.Print("LM Studio endpoint [http://127.0.0.1:1234]: ")
 			endpoint, _ := reader.ReadString('\n')
@@ -141,19 +155,19 @@ func runInit(cmd *cobra.Command, args []string) error {
 			cfg.Embedder.Model = "text-embedding-nomic-embed-text-v1.5"
 			dim := lmStudioEmbeddingDimensions
 			cfg.Embedder.Dimensions = &dim
-		case "3", "openai":
+		case "4", "openai":
"openai" cfg.Embedder.Model = config.DefaultOpenAIEmbeddingModel cfg.Embedder.Endpoint = "https://api.openai.com/v1" cfg.Embedder.Parallelism = config.DefaultOpenAIParallelism // OpenAI: leave Dimensions nil to use model's native dimensions - case "4", "synthetic": + case "5", "synthetic": cfg.Embedder.Provider = "synthetic" cfg.Embedder.Model = "hf:nomic-ai/nomic-embed-text-v1.5" cfg.Embedder.Endpoint = "https://api.synthetic.new/openai/v1" dim := 768 cfg.Embedder.Dimensions = &dim - case "5", "openrouter": + case "6", "openrouter": cfg.Embedder.Provider = "openrouter" cfg.Embedder.Endpoint = "https://openrouter.ai/api/v1" // OpenRouter: leave Dimensions nil to use model's native dimensions @@ -189,6 +203,14 @@ func runInit(cmd *cobra.Command, args []string) error { } else { cfg.Embedder.Provider = initProvider switch initProvider { + case "llamacpp": + if !managedLlamaCPPSupported() { + return managedLlamaCPPUnsupportedError() + } + cfg.Embedder.Model = resolveInitModel(initProvider, initModel) + cfg.Embedder.Endpoint = config.DefaultLlamaCPPEndpoint + dim := resolveLocalModelDimensions(cfg.Embedder.Model) + cfg.Embedder.Dimensions = &dim case "lmstudio": cfg.Embedder.Model = "text-embedding-nomic-embed-text-v1.5" cfg.Embedder.Endpoint = "http://127.0.0.1:1234" @@ -276,6 +298,14 @@ func runInit(cmd *cobra.Command, args []string) error { cfg.Embedder.Provider = initProvider // Apply provider-specific settings switch initProvider { + case "llamacpp": + if !managedLlamaCPPSupported() { + return managedLlamaCPPUnsupportedError() + } + cfg.Embedder.Model = resolveInitModel(initProvider, initModel) + cfg.Embedder.Endpoint = config.DefaultLlamaCPPEndpoint + dim := resolveLocalModelDimensions(cfg.Embedder.Model) + cfg.Embedder.Dimensions = &dim case "lmstudio": cfg.Embedder.Model = "text-embedding-nomic-embed-text-v1.5" cfg.Embedder.Endpoint = "http://127.0.0.1:1234" @@ -328,6 +358,15 @@ func runInit(cmd *cobra.Command, args []string) error { case "ollama": fmt.Println("\nMake sure Ollama is running with the nomic-embed-text model:") fmt.Println(" ollama pull nomic-embed-text") + case "llamacpp": + if hasInstalledManagedModel(cfg.Embedder.Model) { + fmt.Printf("\nUsing managed local model: %s\n", cfg.Embedder.Model) + fmt.Println("Switch models later with:") + fmt.Println(" grepai model use ") + } else { + fmt.Println("\nInstall the managed local model before starting watch:") + fmt.Println(" grepai model install") + } case "lmstudio": fmt.Println("\nMake sure LM Studio is running with an embedding model loaded.") fmt.Printf(" Model: %s\n", cfg.Embedder.Model) @@ -352,6 +391,14 @@ func shouldPromptInheritChoice(shouldInherit, nonInteractive, uiMode bool) bool func resolveInitModel(provider, requestedModel string) string { requestedModel = strings.TrimSpace(requestedModel) switch provider { + case "llamacpp": + if requestedModel != "" { + if def, err := managedassets.LookupModel(requestedModel); err == nil { + return def.ID + } + return requestedModel + } + return config.DefaultLlamaCPPEmbeddingModel case "openai": if requestedModel != "" { return requestedModel @@ -372,3 +419,56 @@ func resolveInitModel(provider, requestedModel string) string { return requestedModel } } + +func resolveLocalModelDimensions(model string) int { + if def, err := managedassets.LookupModel(model); err == nil && def.Dimensions > 0 { + return def.Dimensions + } + return config.DefaultLlamaCPPDimensions +} + +func resolveInteractiveLlamaCPPModel(reader *bufio.Reader, out io.Writer, requestedModel string) string { + if 
+	if model := strings.TrimSpace(requestedModel); model != "" {
+		return resolveInitModel("llamacpp", model)
+	}
+
+	installedModels, err := managedassets.LoadInstalledModels()
+	if err != nil || len(installedModels) == 0 {
+		return config.DefaultLlamaCPPEmbeddingModel
+	}
+
+	fmt.Fprintln(out, "\nSelect managed local model:")
+	defaultChoice := 1
+	for i, model := range installedModels {
+		fmt.Fprintf(out, " %d) %s (%s, %d dims)\n", i+1, model.ID, formatSize(model.SizeBytes), model.Dimensions)
+		if model.ID == config.DefaultLlamaCPPEmbeddingModel {
+			defaultChoice = i + 1
+		}
+	}
+	fmt.Fprintf(out, "Choice [%d]: ", defaultChoice)
+
+	input, _ := reader.ReadString('\n')
+	input = strings.TrimSpace(input)
+	if input == "" {
+		return installedModels[defaultChoice-1].ID
+	}
+	for i, model := range installedModels {
+		if input == fmt.Sprintf("%d", i+1) || input == model.ID {
+			return model.ID
+		}
+	}
+	return installedModels[defaultChoice-1].ID
+}
+
+func hasInstalledManagedModel(modelID string) bool {
+	models, err := managedassets.LoadInstalledModels()
+	if err != nil {
+		return false
+	}
+	for _, model := range models {
+		if model.ID == modelID {
+			return true
+		}
+	}
+	return false
+}
diff --git a/cli/init_test.go b/cli/init_test.go
index 2929ac1..b823dfe 100644
--- a/cli/init_test.go
+++ b/cli/init_test.go
@@ -1,10 +1,16 @@
 package cli

 import (
+	"bufio"
+	"bytes"
 	"os"
+	"path/filepath"
+	"runtime"
+	"strings"
 	"testing"

 	"github.com/yoanbernabeu/grepai/config"
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
 )

 func withInitTestState(t *testing.T, dir string, configure func()) {
@@ -44,6 +50,22 @@ func withInitTestState(t *testing.T, dir string, configure func()) {
 	})
 }

+func setInitTestHome(t *testing.T, dir string) func() {
+	t.Helper()
+	originalHome := os.Getenv("HOME")
+	if runtime.GOOS == "windows" {
+		originalProfile := os.Getenv("USERPROFILE")
+		_ = os.Setenv("USERPROFILE", dir)
+		return func() {
+			_ = os.Setenv("USERPROFILE", originalProfile)
+		}
+	}
+	_ = os.Setenv("HOME", dir)
+	return func() {
+		_ = os.Setenv("HOME", originalHome)
+	}
+}
+
 func TestRunInit_OpenAIExplicitModelHonored(t *testing.T) {
 	tmpDir := t.TempDir()
 	withInitTestState(t, tmpDir, func() {
@@ -95,3 +117,128 @@ func TestRunInit_OpenAIDefaultsToOpenAISmallModel(t *testing.T) {
 		t.Fatalf("parallelism = %d, want %d", cfg.Embedder.Parallelism, config.DefaultOpenAIParallelism)
 	}
 }
+
+func TestRunInit_LlamaCPPDefaults(t *testing.T) {
+	tmpDir := t.TempDir()
+	withInitTestState(t, tmpDir, func() {
+		initProvider = "llamacpp"
+		initBackend = "gob"
+		initNonInteractive = true
+	})
+
+	if err := runInit(nil, nil); err != nil {
+		t.Fatalf("runInit: %v", err)
+	}
+
+	cfg, err := config.Load(tmpDir)
+	if err != nil {
+		t.Fatalf("config.Load: %v", err)
+	}
+	if cfg.Embedder.Provider != "llamacpp" {
+		t.Fatalf("provider = %q, want llamacpp", cfg.Embedder.Provider)
+	}
+	if cfg.Embedder.Model != config.DefaultLlamaCPPEmbeddingModel {
+		t.Fatalf("model = %q, want %q", cfg.Embedder.Model, config.DefaultLlamaCPPEmbeddingModel)
+	}
+	if cfg.Embedder.Endpoint != config.DefaultLlamaCPPEndpoint {
+		t.Fatalf("endpoint = %q, want %q", cfg.Embedder.Endpoint, config.DefaultLlamaCPPEndpoint)
+	}
+}
+
+func TestRunInit_LlamaCPPExplicitModelHonored(t *testing.T) {
+	tmpDir := t.TempDir()
+	withInitTestState(t, tmpDir, func() {
+		initProvider = "llamacpp"
+		initModel = "nomic-embed-text-v1.5-q8_0"
+		initBackend = "gob"
+		initNonInteractive = true
+	})
+
+	if err := runInit(nil, nil); err != nil {
+		t.Fatalf("runInit: %v", err)
+	}
+
+	cfg, err := config.Load(tmpDir)
+	if err != nil {
+		t.Fatalf("config.Load: %v", err)
+	}
+	if cfg.Embedder.Model != "nomic-embed-text-v1.5-q8_0" {
+		t.Fatalf("model = %q, want nomic-embed-text-v1.5-q8_0", cfg.Embedder.Model)
+	}
+	if cfg.Embedder.Dimensions == nil || *cfg.Embedder.Dimensions != 768 {
+		t.Fatalf("dimensions = %v, want 768", cfg.Embedder.Dimensions)
+	}
+}
+
+func TestResolveInteractiveLlamaCPPModelSelectsInstalledModel(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setInitTestHome(t, tmpDir)
+	defer cleanup()
+
+	modelDef, err := managedassets.LookupModel("nomic-embed-text-v1.5-q8_0")
+	if err != nil {
+		t.Fatalf("LookupModel failed: %v", err)
+	}
+	modelPath := filepath.Join(tmpDir, modelDef.FileName)
+	if err := os.WriteFile(modelPath, []byte("stub"), 0o644); err != nil {
+		t.Fatalf("WriteFile failed: %v", err)
+	}
+	if err := managedassets.SaveInstalledModels([]managedassets.InstalledModel{
+		{
+			ID: managedassets.DefaultModelID,
+			FileName: "bge-small-en-v1.5-q8_0.gguf",
+			Path: filepath.Join(tmpDir, "bge-small-en-v1.5-q8_0.gguf"),
+			SourceURL: "https://example.com/bge",
+			SizeBytes: 36685152,
+			Dimensions: 384,
+		},
+		{
+			ID: modelDef.ID,
+			FileName: modelDef.FileName,
+			Path: modelPath,
+			SourceURL: modelDef.URL,
+			SizeBytes: modelDef.SizeBytes,
+			Dimensions: modelDef.Dimensions,
+		},
+	}); err != nil {
+		t.Fatalf("SaveInstalledModels failed: %v", err)
+	}
+
+	var out bytes.Buffer
+	reader := bufio.NewReader(strings.NewReader("2\n"))
+	selected := resolveInteractiveLlamaCPPModel(reader, &out, "")
+	if selected != modelDef.ID {
+		t.Fatalf("selected = %q, want %q", selected, modelDef.ID)
+	}
+	if !strings.Contains(out.String(), "Select managed local model") {
+		t.Fatalf("expected prompt output, got %q", out.String())
+	}
+}
+
+func TestHasInstalledManagedModel(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setInitTestHome(t, tmpDir)
+	defer cleanup()
+
+	modelDef, err := managedassets.LookupModel("nomic-embed-text-v1.5-q8_0")
+	if err != nil {
+		t.Fatalf("LookupModel failed: %v", err)
+	}
+	if err := managedassets.SaveInstalledModels([]managedassets.InstalledModel{{
+		ID: modelDef.ID,
+		FileName: modelDef.FileName,
+		Path: filepath.Join(tmpDir, modelDef.FileName),
+		SourceURL: modelDef.URL,
+		SizeBytes: modelDef.SizeBytes,
+		Dimensions: modelDef.Dimensions,
+	}}); err != nil {
+		t.Fatalf("SaveInstalledModels failed: %v", err)
+	}
+
+	if !hasInstalledManagedModel(modelDef.ID) {
+		t.Fatalf("expected model %q to be reported as installed", modelDef.ID)
+	}
+	if hasInstalledManagedModel("missing-model") {
+		t.Fatal("expected missing model to be reported as not installed")
+	}
+}
diff --git a/cli/llamacpp_support.go b/cli/llamacpp_support.go
new file mode 100644
index 0000000..7ce66b0
--- /dev/null
+++ b/cli/llamacpp_support.go
@@ -0,0 +1,33 @@
+package cli
+
+import (
+	"fmt"
+	"runtime"
+
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
+)
+
+func managedLlamaCPPSupported() bool {
+	_, err := managedassets.LookupCurrentRuntime()
+	return err == nil
+}
+
+func managedLlamaCPPUnsupportedError() error {
+	return fmt.Errorf("managed llama.cpp is not available on %s/%s", runtime.GOOS, runtime.GOARCH)
+}
+
+func availableInitProviders() []string {
+	providers := []string{"ollama", "lmstudio", "openai"}
+	if managedLlamaCPPSupported() {
+		providers = []string{"ollama", "llamacpp", "lmstudio", "openai"}
+	}
+	return providers
+}
+
+func availableWorkspaceProviders() []string {
+	providers := []string{"ollama", "openai", "lmstudio"}
+	if managedLlamaCPPSupported() {
+ providers = []string{"ollama", "llamacpp", "openai", "lmstudio"} + } + return providers +} diff --git a/cli/model.go b/cli/model.go new file mode 100644 index 0000000..7cff1c5 --- /dev/null +++ b/cli/model.go @@ -0,0 +1,184 @@ +package cli + +import ( + "context" + "fmt" + "text/tabwriter" + + "github.com/spf13/cobra" + "github.com/yoanbernabeu/grepai/config" + "github.com/yoanbernabeu/grepai/internal/managedassets" +) + +var modelCmd = &cobra.Command{ + Use: "model", + Short: "Manage locally installed llama.cpp embedding models", + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + if !managedLlamaCPPSupported() { + return managedLlamaCPPUnsupportedError() + } + return nil + }, +} + +var modelInstallCmd = &cobra.Command{ + Use: "install [model]", + Short: "Install a managed local embedding model", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + modelID := managedassets.DefaultModelID + if len(args) == 1 && args[0] != "" { + modelID = args[0] + } + fmt.Fprintf(cmd.OutOrStdout(), "Installing managed model %s...\n", modelID) + model, err := managedassets.InstallModel(context.Background(), modelID, func(downloaded, total int64) { + renderDownloadProgress("Model", downloaded, total) + }) + fmt.Fprint(cmd.OutOrStdout(), "\r"+progressPadding()+"\r") + if err != nil { + return err + } + fmt.Fprintf(cmd.OutOrStdout(), "Installed model %s at %s\n", model.ID, model.Path) + return nil + }, +} + +var modelListCmd = &cobra.Command{ + Use: "list", + Short: "List installed managed local models", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + models, err := managedassets.LoadInstalledModels() + if err != nil { + return err + } + if len(models) == 0 { + fmt.Fprintln(cmd.OutOrStdout(), "No managed local models installed") + return nil + } + tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 2, 2, ' ', 0) + fmt.Fprintln(tw, "MODEL\tSIZE\tDIMENSIONS\tPATH") + for _, model := range models { + fmt.Fprintf(tw, "%s\t%s\t%d\t%s\n", model.ID, formatSize(model.SizeBytes), model.Dimensions, model.Path) + } + return tw.Flush() + }, +} + +var modelListAvailableCmd = &cobra.Command{ + Use: "list-available", + Short: "List available managed local models", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + models := managedassets.ListAvailableModels() + if len(models) == 0 { + fmt.Fprintln(cmd.OutOrStdout(), "No managed local models available") + return nil + } + tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 2, 2, ' ', 0) + fmt.Fprintln(tw, "MODEL\tSIZE\tDIMENSIONS\tNAME") + for _, model := range models { + fmt.Fprintf(tw, "%s\t%s\t%d\t%s\n", model.ID, formatSize(model.SizeBytes), model.Dimensions, model.Display) + } + return tw.Flush() + }, +} + +var modelUseCmd = &cobra.Command{ + Use: "use ", + Short: "Use an installed managed local model for the current project", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + modelDef, err := managedassets.LookupModel(args[0]) + if err != nil { + return err + } + + installedModels, err := managedassets.LoadInstalledModels() + if err != nil { + return err + } + installed := false + for _, model := range installedModels { + if model.ID == modelDef.ID { + installed = true + break + } + } + if !installed { + return fmt.Errorf("managed model %q is not installed; run 'grepai model install %s'", modelDef.ID, modelDef.ID) + } + + projectRoot, err := config.FindProjectRoot() + if err != nil { + return err + } + cfg, err := 
+		cfg, err := config.Load(projectRoot)
+		if err != nil {
+			return err
+		}
+
+		cfg.Embedder.Provider = "llamacpp"
+		cfg.Embedder.Model = modelDef.ID
+		cfg.Embedder.ModelPath = ""
+		cfg.Embedder.Endpoint = config.DefaultLlamaCPPEndpoint
+		cfg.Embedder.Parallelism = 0
+		dim := modelDef.Dimensions
+		cfg.Embedder.Dimensions = &dim
+
+		if err := cfg.Save(projectRoot); err != nil {
+			return err
+		}
+
+		fmt.Fprintf(cmd.OutOrStdout(), "Configured %s to use model %s\n", projectRoot, modelDef.ID)
+		return nil
+	},
+}
+
+var modelRemoveCmd = &cobra.Command{
+	Use: "remove <model>",
+	Short: "Remove an installed managed local model",
+	Args: cobra.ExactArgs(1),
+	RunE: func(cmd *cobra.Command, args []string) error {
+		if err := managedassets.RemoveInstalledModel(args[0]); err != nil {
+			return err
+		}
+		fmt.Fprintf(cmd.OutOrStdout(), "Removed model %s\n", args[0])
+		return nil
+	},
+}
+
+func init() {
+	modelCmd.AddCommand(modelInstallCmd)
+	modelCmd.AddCommand(modelListCmd)
+	modelCmd.AddCommand(modelListAvailableCmd)
+	modelCmd.AddCommand(modelUseCmd)
+	modelCmd.AddCommand(modelRemoveCmd)
+	rootCmd.AddCommand(modelCmd)
+}
+
+func renderDownloadProgress(label string, downloaded, total int64) {
+	if total > 0 {
+		percent := float64(downloaded) / float64(total) * 100
+		fmt.Printf("\r%s [%s] %.0f%%", label, progressBar(int(percent), 30), percent)
+		return
+	}
+	fmt.Printf("\r%s %d bytes", label, downloaded)
+}
+
+func progressPadding() string {
+	return fmt.Sprintf("%60s", "")
+}
+
+func formatSize(sizeBytes int64) string {
+	switch {
+	case sizeBytes <= 0:
+		return "-"
+	case sizeBytes < 1024*1024:
+		return fmt.Sprintf("%.0f KB", float64(sizeBytes)/1024)
+	case sizeBytes < 1024*1024*1024:
+		return fmt.Sprintf("%.1f MB", float64(sizeBytes)/(1024*1024))
+	default:
+		return fmt.Sprintf("%.2f GB", float64(sizeBytes)/(1024*1024*1024))
+	}
+}
diff --git a/cli/model_test.go b/cli/model_test.go
new file mode 100644
index 0000000..881b4af
--- /dev/null
+++ b/cli/model_test.go
@@ -0,0 +1,188 @@
+package cli
+
+import (
+	"bytes"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strings"
+	"testing"
+
+	"github.com/yoanbernabeu/grepai/config"
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
+)
+
+func setModelTestHome(t *testing.T, dir string) func() {
+	t.Helper()
+	original := os.Getenv("HOME")
+	if runtime.GOOS == "windows" {
+		original = os.Getenv("USERPROFILE")
+		os.Setenv("USERPROFILE", dir)
+		return func() { os.Setenv("USERPROFILE", original) }
+	}
+	os.Setenv("HOME", dir)
+	return func() { os.Setenv("HOME", original) }
+}
+
+func TestModelListCommand(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setModelTestHome(t, tmpDir)
+	defer cleanup()
+
+	models := []managedassets.InstalledModel{{
+		ID: managedassets.DefaultModelID,
+		FileName: "embedding.gguf",
+		Path: filepath.Join(tmpDir, "embedding.gguf"),
+		SourceURL: "https://example.com/embedding.gguf",
+		SizeBytes: 123456,
+		Dimensions: 768,
+	}}
+	if err := managedassets.SaveInstalledModels(models); err != nil {
+		t.Fatalf("SaveInstalledModels failed: %v", err)
+	}
+
+	var buf bytes.Buffer
+	modelListCmd.SetOut(&buf)
+	modelListCmd.SetArgs(nil)
+	defer modelListCmd.SetOut(nil)
+
+	if err := modelListCmd.RunE(modelListCmd, nil); err != nil {
+		t.Fatalf("model list failed: %v", err)
+	}
+	if !strings.Contains(buf.String(), managedassets.DefaultModelID) {
+		t.Fatalf("expected model list output to mention %s, got %q", managedassets.DefaultModelID, buf.String())
+	}
+	if !strings.Contains(buf.String(), "121 KB") {
size, got %q", buf.String()) + } +} + +func TestModelRemoveCommand(t *testing.T) { + tmpDir := t.TempDir() + cleanup := setModelTestHome(t, tmpDir) + defer cleanup() + + modelPath := filepath.Join(tmpDir, "embedding.gguf") + if err := os.WriteFile(modelPath, []byte("stub"), 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if err := managedassets.SaveInstalledModels([]managedassets.InstalledModel{{ + ID: managedassets.DefaultModelID, + FileName: "embedding.gguf", + Path: modelPath, + SourceURL: "https://example.com/embedding.gguf", + SizeBytes: 4, + Dimensions: 768, + }}); err != nil { + t.Fatalf("SaveInstalledModels failed: %v", err) + } + + if err := modelRemoveCmd.RunE(modelRemoveCmd, []string{managedassets.DefaultModelID}); err != nil { + t.Fatalf("model remove failed: %v", err) + } + models, err := managedassets.LoadInstalledModels() + if err != nil { + t.Fatalf("LoadInstalledModels failed: %v", err) + } + if len(models) != 0 { + t.Fatalf("expected model manifest to be empty, got %+v", models) + } + if _, err := os.Stat(modelPath); !os.IsNotExist(err) { + t.Fatalf("expected model file to be removed, stat err=%v", err) + } +} + +func TestModelListAvailableCommand(t *testing.T) { + var buf bytes.Buffer + modelListAvailableCmd.SetOut(&buf) + modelListAvailableCmd.SetArgs(nil) + defer modelListAvailableCmd.SetOut(nil) + + if err := modelListAvailableCmd.RunE(modelListAvailableCmd, nil); err != nil { + t.Fatalf("model list-available failed: %v", err) + } + if !strings.Contains(buf.String(), managedassets.DefaultModelID) { + t.Fatalf("expected available model output to mention %s, got %q", managedassets.DefaultModelID, buf.String()) + } + if !strings.Contains(buf.String(), "35.0 MB") { + t.Fatalf("expected available model output to include formatted size, got %q", buf.String()) + } + if !strings.Contains(buf.String(), "nomic-embed-text-v1.5-q8_0") { + t.Fatalf("expected available model output to include Nomic option, got %q", buf.String()) + } +} + +func TestModelUseCommand(t *testing.T) { + tmpDir := t.TempDir() + cleanup := setModelTestHome(t, tmpDir) + defer cleanup() + + projectDir := filepath.Join(tmpDir, "project") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + + cfg := config.DefaultConfig() + if err := cfg.Save(projectDir); err != nil { + t.Fatalf("cfg.Save failed: %v", err) + } + + modelDef, err := managedassets.LookupModel("nomic-embed-text-v1.5-q8_0") + if err != nil { + t.Fatalf("LookupModel failed: %v", err) + } + modelPath := filepath.Join(tmpDir, modelDef.FileName) + if err := os.WriteFile(modelPath, []byte("stub"), 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if err := managedassets.SaveInstalledModels([]managedassets.InstalledModel{{ + ID: modelDef.ID, + FileName: modelDef.FileName, + Path: modelPath, + SourceURL: modelDef.URL, + SizeBytes: int64(len("stub")), + Dimensions: modelDef.Dimensions, + }}); err != nil { + t.Fatalf("SaveInstalledModels failed: %v", err) + } + + prevCwd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd failed: %v", err) + } + defer func() { + _ = os.Chdir(prevCwd) + }() + if err := os.Chdir(projectDir); err != nil { + t.Fatalf("Chdir failed: %v", err) + } + + var buf bytes.Buffer + modelUseCmd.SetOut(&buf) + modelUseCmd.SetArgs(nil) + defer modelUseCmd.SetOut(nil) + + if err := modelUseCmd.RunE(modelUseCmd, []string{modelDef.ID}); err != nil { + t.Fatalf("model use failed: %v", err) + } + + updated, err := config.Load(projectDir) + if err != nil { + 
t.Fatalf("config.Load failed: %v", err) + } + if updated.Embedder.Provider != "llamacpp" { + t.Fatalf("provider = %q, want llamacpp", updated.Embedder.Provider) + } + if updated.Embedder.Model != modelDef.ID { + t.Fatalf("model = %q, want %q", updated.Embedder.Model, modelDef.ID) + } + if updated.Embedder.Endpoint != config.DefaultLlamaCPPEndpoint { + t.Fatalf("endpoint = %q, want %q", updated.Embedder.Endpoint, config.DefaultLlamaCPPEndpoint) + } + if updated.Embedder.Dimensions == nil || *updated.Embedder.Dimensions != modelDef.Dimensions { + t.Fatalf("dimensions = %v, want %d", updated.Embedder.Dimensions, modelDef.Dimensions) + } + if !strings.Contains(buf.String(), modelDef.ID) { + t.Fatalf("expected output to mention selected model, got %q", buf.String()) + } +} diff --git a/cli/tui_init.go b/cli/tui_init.go index 6e2b511..3f8ba7f 100644 --- a/cli/tui_init.go +++ b/cli/tui_init.go @@ -26,7 +26,7 @@ const ( initStepReview ) -var initProviderOptions = []string{"ollama", "lmstudio", "openai"} +var initProviderOptions = availableInitProviders() var initBackendOptions = []string{"gob", "postgres", "qdrant"} type initUIModel struct { diff --git a/cli/tui_workspace.go b/cli/tui_workspace.go index 9fbbc92..ea951f1 100644 --- a/cli/tui_workspace.go +++ b/cli/tui_workspace.go @@ -35,6 +35,8 @@ type workspaceCreateModel struct { result *config.Workspace } +var workspaceProviderOptions = availableWorkspaceProviders() + func newWorkspaceCreateModel(workspaceName string) workspaceCreateModel { return workspaceCreateModel{ theme: newTUITheme(), @@ -62,13 +64,13 @@ func (m workspaceCreateModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.step == workspaceStepBackend { m.backendIdx = wrapIndex(m.backendIdx-1, 2) } else if m.step == workspaceStepProvider { - m.providerIdx = wrapIndex(m.providerIdx-1, 3) + m.providerIdx = wrapIndex(m.providerIdx-1, len(workspaceProviderOptions)) } case "down", "j": if m.step == workspaceStepBackend { m.backendIdx = wrapIndex(m.backendIdx+1, 2) } else if m.step == workspaceStepProvider { - m.providerIdx = wrapIndex(m.providerIdx+1, 3) + m.providerIdx = wrapIndex(m.providerIdx+1, len(workspaceProviderOptions)) } case "b": if m.step > workspaceStepBackend { @@ -118,7 +120,7 @@ func (m workspaceCreateModel) renderStep() string { } return strings.Join(lines, "\n") case workspaceStepProvider: - options := []string{"ollama", "openai", "lmstudio"} + options := workspaceProviderOptions lines := []string{m.theme.subtitle.Render("Select embedding provider"), ""} for i, opt := range options { prefix := " " @@ -162,10 +164,12 @@ func buildWorkspaceFromSelection(name string, backendIdx, providerIdx int) *conf } provider := "ollama" - switch providerIdx { - case 1: + switch workspaceProviderOptions[providerIdx] { + case "llamacpp": + provider = "llamacpp" + case "openai": provider = "openai" - case 2: + case "lmstudio": provider = "lmstudio" } diff --git a/cli/tui_workspace_test.go b/cli/tui_workspace_test.go index fbdd18b..9951f5a 100644 --- a/cli/tui_workspace_test.go +++ b/cli/tui_workspace_test.go @@ -20,3 +20,14 @@ func TestBuildWorkspaceFromSelectionMatchesFlagsBuilder(t *testing.T) { t.Fatalf("qdrant port = %d, want %d", ws.Store.Qdrant.Port, ref.Store.Qdrant.Port) } } + +func TestBuildWorkspaceFromSelection_LlamaCPP(t *testing.T) { + ws := buildWorkspaceFromSelection("demo", 1, 1) + + if ws.Embedder.Provider != "llamacpp" { + t.Fatalf("provider = %s, want llamacpp", ws.Embedder.Provider) + } + if ws.Embedder.Model != "bge-small-en-v1.5-q8_0" { + t.Fatalf("model = %s, want 
bge-small-en-v1.5-q8_0", ws.Embedder.Model) + } +} diff --git a/cli/watch.go b/cli/watch.go index d56c1bc..ce6c160 100644 --- a/cli/watch.go +++ b/cli/watch.go @@ -424,6 +424,12 @@ func initializeEmbedder(ctx context.Context, cfg *config.Config) (embedder.Embed return nil, fmt.Errorf("cannot connect to Ollama: %w\nMake sure Ollama is running and has the %s model", err, cfg.Embedder.Model) } } + case "llamacpp": + if p, ok := emb.(pinger); ok { + if err := p.Ping(ctx); err != nil { + return nil, fmt.Errorf("cannot connect to managed llama.cpp runtime: %w\nInstall the managed model with 'grepai model install' and ensure the runtime can be downloaded for this platform", err) + } + } case "lmstudio": if p, ok := emb.(pinger); ok { if err := p.Ping(ctx); err != nil { diff --git a/cli/workspace.go b/cli/workspace.go index ec453fb..1c2879d 100644 --- a/cli/workspace.go +++ b/cli/workspace.go @@ -99,7 +99,7 @@ func init() { // Non-interactive workspace create flags workspaceCreateCmd.Flags().String("backend", "", "Storage backend: postgres, qdrant") - workspaceCreateCmd.Flags().String("provider", "", "Embedding provider: ollama, openai, lmstudio") + workspaceCreateCmd.Flags().String("provider", "", "Embedding provider: ollama, llamacpp, openai, lmstudio") workspaceCreateCmd.Flags().String("model", "", "Embedding model name") workspaceCreateCmd.Flags().String("endpoint", "", "Embedder endpoint URL") workspaceCreateCmd.Flags().String("dsn", "", "PostgreSQL DSN (when backend=postgres)") @@ -254,10 +254,15 @@ func buildWorkspaceFromFlags(name, backend, provider, model, dsn, endpoint, qdra if provider == "" { provider = "ollama" } + if provider == "llamacpp" && !managedLlamaCPPSupported() { + return nil, managedLlamaCPPUnsupportedError() + } if model == "" { switch provider { case "openai": model = config.DefaultOpenAIEmbeddingModel + case "llamacpp": + model = config.DefaultLlamaCPPEmbeddingModel default: model = config.DefaultOllamaEmbeddingModel } @@ -298,6 +303,13 @@ func buildWorkspaceFromFlags(name, backend, provider, model, dsn, endpoint, qdra embedderConfig.Endpoint = endpoint dim := config.DefaultLocalEmbeddingDimensions embedderConfig.Dimensions = &dim + case "llamacpp": + if endpoint == "" { + endpoint = config.DefaultLlamaCPPEndpoint + } + embedderConfig.Endpoint = endpoint + dim := config.DefaultLlamaCPPDimensions + embedderConfig.Dimensions = &dim case "lmstudio": if endpoint == "" { endpoint = config.DefaultLMStudioEndpoint @@ -312,7 +324,7 @@ func buildWorkspaceFromFlags(name, backend, provider, model, dsn, endpoint, qdra embedderConfig.Endpoint = endpoint embedderConfig.Parallelism = workspaceCreateOpenAIParallelism default: - return nil, fmt.Errorf("unsupported provider: %s (use ollama, openai, or lmstudio)", provider) + return nil, fmt.Errorf("unsupported provider: %s (use ollama, llamacpp, openai, or lmstudio)", provider) } return &config.Workspace{ @@ -504,8 +516,11 @@ func createWorkspaceInteractive(workspaceName string) (*config.Workspace, error) fmt.Println("\nSelect embedding provider:") fmt.Println(" 1. Ollama (local, default)") - fmt.Println(" 2. OpenAI") - fmt.Println(" 3. LM Studio (local)") + if managedLlamaCPPSupported() { + fmt.Println(" 2. llama.cpp (managed local)") + } + fmt.Println(" 3. OpenAI") + fmt.Println(" 4. 
LM Studio (local)") fmt.Print("Choice [1]: ") embedderChoice, _ := reader.ReadString('\n') embedderChoice = strings.TrimSpace(embedderChoice) @@ -534,6 +549,21 @@ func createWorkspaceInteractive(workspaceName string) (*config.Workspace, error) dim := 768 embedderConfig.Dimensions = &dim case "2": + if !managedLlamaCPPSupported() { + return nil, managedLlamaCPPUnsupportedError() + } + embedderConfig.Provider = "llamacpp" + embedderConfig.Endpoint = config.DefaultLlamaCPPEndpoint + fmt.Printf("Managed model [%s]: ", config.DefaultLlamaCPPEmbeddingModel) + model, _ := reader.ReadString('\n') + model = strings.TrimSpace(model) + if model == "" { + model = config.DefaultLlamaCPPEmbeddingModel + } + embedderConfig.Model = model + dim := config.DefaultLlamaCPPDimensions + embedderConfig.Dimensions = &dim + case "3": embedderConfig.Provider = "openai" fmt.Print("OpenAI API Key: ") apiKey, _ := reader.ReadString('\n') @@ -547,7 +577,7 @@ func createWorkspaceInteractive(workspaceName string) (*config.Workspace, error) embedderConfig.Model = model embedderConfig.Endpoint = config.DefaultOpenAIEndpoint embedderConfig.Parallelism = workspaceCreateOpenAIParallelism - case "3": + case "4": embedderConfig.Provider = "lmstudio" fmt.Print("LM Studio endpoint [http://127.0.0.1:1234]: ") endpoint, _ := reader.ReadString('\n') diff --git a/cli/workspace_create_test.go b/cli/workspace_create_test.go index 7386749..b7e3dd5 100644 --- a/cli/workspace_create_test.go +++ b/cli/workspace_create_test.go @@ -76,6 +76,27 @@ func TestCreateWorkspaceNonInteractive(t *testing.T) { } }) + t.Run("flags_qdrant_llamacpp", func(t *testing.T) { + tmpDir, _ := os.MkdirTemp("", "grepai-test-cli") + defer os.RemoveAll(tmpDir) + cleanup := setTestHomeDirCLI(t, tmpDir) + defer cleanup() + + ws, err := buildWorkspaceFromFlags("test-ws", "qdrant", "llamacpp", "", "", "", "http://localhost", 6334, "", false) + if err != nil { + t.Fatalf("buildWorkspaceFromFlags error: %v", err) + } + if ws.Embedder.Provider != "llamacpp" { + t.Errorf("expected llamacpp provider, got %s", ws.Embedder.Provider) + } + if ws.Embedder.Model != config.DefaultLlamaCPPEmbeddingModel { + t.Errorf("expected default llamacpp model %s, got %s", config.DefaultLlamaCPPEmbeddingModel, ws.Embedder.Model) + } + if ws.Embedder.Endpoint != config.DefaultLlamaCPPEndpoint { + t.Errorf("expected endpoint %s, got %s", config.DefaultLlamaCPPEndpoint, ws.Embedder.Endpoint) + } + }) + t.Run("flags_postgres_openai_default_model_and_parallelism", func(t *testing.T) { tmpDir, _ := os.MkdirTemp("", "grepai-test-cli") defer os.RemoveAll(tmpDir) diff --git a/config/config.go b/config/config.go index 38da06b..8949c95 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,7 @@ const ( DefaultEmbedderProvider = "ollama" DefaultOllamaEmbeddingModel = "nomic-embed-text" + DefaultLlamaCPPEmbeddingModel = "bge-small-en-v1.5-q8_0" DefaultLMStudioEmbeddingModel = "text-embedding-nomic-embed-text-v1.5" DefaultOpenAIEmbeddingModel = "text-embedding-3-small" DefaultSyntheticEmbeddingModel = "hf:nomic-ai/nomic-embed-text-v1.5" @@ -29,12 +30,14 @@ const ( OpenRouterEmbeddingModelQwen8B = "qwen/qwen3-embedding-8b" DefaultOllamaEndpoint = "http://localhost:11434" + DefaultLlamaCPPEndpoint = "http://127.0.0.1:12434" DefaultLMStudioEndpoint = "http://127.0.0.1:1234" DefaultOpenAIEndpoint = "https://api.openai.com/v1" DefaultSyntheticEndpoint = "https://api.synthetic.new/openai/v1" DefaultOpenRouterEndpoint = "https://openrouter.ai/api/v1" DefaultLocalEmbeddingDimensions = 768 + 
+	DefaultLlamaCPPDimensions = 384
 	DefaultOpenAIDimensions = 1536
 	DefaultOpenAILargeDimensions = 3072
 	DefaultQwen8BDimensions = 4096
@@ -99,8 +102,9 @@ type BoostRule struct {
 }

 type EmbedderConfig struct {
-	Provider string `yaml:"provider"` // ollama | lmstudio | openai | synthetic | openrouter
+	Provider string `yaml:"provider"` // ollama | llamacpp | lmstudio | openai | synthetic | openrouter
 	Model string `yaml:"model"`
+	ModelPath string `yaml:"model_path,omitempty"`
 	Endpoint string `yaml:"endpoint,omitempty"`
 	APIKey string `yaml:"api_key,omitempty"`
 	Dimensions *int `yaml:"dimensions,omitempty"`
@@ -154,6 +158,14 @@ func DefaultEmbedderForProvider(provider string) EmbedderConfig {
 			Endpoint: DefaultLMStudioEndpoint,
 			Dimensions: &dim,
 		}
+	case "llamacpp":
+		dim := DefaultLlamaCPPDimensions
+		return EmbedderConfig{
+			Provider: "llamacpp",
+			Model: DefaultLlamaCPPEmbeddingModel,
+			Endpoint: DefaultLlamaCPPEndpoint,
+			Dimensions: &dim,
+		}
 	case "openai":
 		return EmbedderConfig{
 			Provider: "openai",
@@ -452,6 +464,9 @@ func (c *Config) applyDefaults() {
 	if c.Embedder.Endpoint == "" {
 		c.Embedder.Endpoint = DefaultEmbedderForProvider(c.Embedder.Provider).Endpoint
 	}
+	if c.Embedder.Model == "" {
+		c.Embedder.Model = DefaultEmbedderForProvider(c.Embedder.Provider).Model
+	}

 	// Only set default dimensions for local embedders.
 	// For OpenAI/OpenRouter, leave nil to let the API use the model's native dimensions.
diff --git a/config/config_test.go b/config/config_test.go
index 21bf467..5b32d07 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -77,6 +77,14 @@ func TestDefaultEmbedderForProvider(t *testing.T) {
 		t.Fatalf("unexpected lmstudio dimensions: %v", lmstudio.Dimensions)
 	}

+	llamacpp := DefaultEmbedderForProvider("llamacpp")
+	if llamacpp.Endpoint != DefaultLlamaCPPEndpoint || llamacpp.Model != DefaultLlamaCPPEmbeddingModel {
+		t.Fatalf("unexpected llamacpp defaults: %+v", llamacpp)
+	}
+	if llamacpp.Dimensions == nil || *llamacpp.Dimensions != DefaultLlamaCPPDimensions {
+		t.Fatalf("unexpected llamacpp dimensions: %v", llamacpp.Dimensions)
+	}
+
 	openai := DefaultEmbedderForProvider("openai")
 	if openai.Endpoint != DefaultOpenAIEndpoint || openai.Model != DefaultOpenAIEmbeddingModel {
 		t.Fatalf("unexpected openai defaults: %+v", openai)
diff --git a/docs/src/content/docs/backends/embedders.md b/docs/src/content/docs/backends/embedders.md
index 3f61a45..6d97dd8 100644
--- a/docs/src/content/docs/backends/embedders.md
+++ b/docs/src/content/docs/backends/embedders.md
@@ -9,10 +9,88 @@ Embedders convert text (code chunks) into vector representations that enable sem

 | Provider | Type | Pros | Cons |
 |----------|------|------|------|
+| llama.cpp (managed) | Local | Privacy, no separate service to install, cross-platform managed assets | Larger local downloads, managed runtime still needs compatible platform binaries |
 | Ollama | Local | Privacy, free, no internet | Requires local resources |
 | LM Studio | Local | Privacy, OpenAI-compatible API, GUI | Requires local resources |
 | OpenAI | Cloud | High quality, fast | Costs money, sends code to cloud |

+## llama.cpp (Managed Local)
+
+grepai can manage a local `llama.cpp` embedding runtime for you. Model files and runtime binaries are stored globally under `~/.grepai`, while each project keeps only its model selection in `.grepai/config.yaml`.
+
+Current managed runtime support:
+- macOS `arm64`
+- macOS `amd64`
+- Linux `amd64`
+- Windows `amd64`
+
+### Setup
+
+1. Initialize with the managed provider:
+
+```bash
+grepai init --provider llamacpp
+```
+
+2. Install the recommended default model:
+
+```bash
+grepai model install
+```
+
+3. Select which installed managed model this project should use:
+
+```bash
+grepai model use bge-small-en-v1.5-q8_0
+```
+
+If you already have one or more managed models installed, plain `grepai init` will prompt you to choose one when you select the `llamacpp` provider.
+
+4. Start indexing normally:
+
+```bash
+grepai watch
+```
+
+### Configuration
+
+```yaml
+embedder:
+  provider: llamacpp
+  model: bge-small-en-v1.5-q8_0
+  endpoint: http://127.0.0.1:12434
+  dimensions: 384
+```
+
+Advanced override with an explicit model path:
+
+```yaml
+embedder:
+  provider: llamacpp
+  model: bge-small-en-v1.5-q8_0
+  model_path: /absolute/path/to/custom-model.gguf
+  endpoint: http://127.0.0.1:12434
+```
+
+### Managed Assets
+
+- Models: `~/.grepai/models`
+- Runtime binaries: `~/.grepai/bin`
+- Runtime metadata/state: `~/.grepai/state`
+
+### Model Commands
+
+```bash
+grepai model install # Install the recommended default model
+grepai model list-available # Show managed model options with file sizes
+grepai model install <model> # Install a specific managed model
+grepai model list # Show installed managed models
+grepai model use <model> # Use an installed managed model for this project
+grepai model remove <model> # Remove an installed managed model
+```
+
+Managed models can carry model-specific embedding behavior. For example, when a Nomic model is selected via the managed `llama.cpp` provider, grepai automatically prepends `search_document:` to indexed chunks and `search_query:` to user queries.
+
 ## Ollama (Local)

 ### Setup
diff --git a/embedder/embedder.go b/embedder/embedder.go
index d72f762..351c2ba 100644
--- a/embedder/embedder.go
+++ b/embedder/embedder.go
@@ -2,6 +2,14 @@ package embedder

 import "context"

+type InputRole string
+
+const (
+	RoleGeneric InputRole = "generic"
+	RoleDocument InputRole = "document"
+	RoleQuery InputRole = "query"
+)
+
 // Embedder defines the interface for text embedding providers
 type Embedder interface {
 	// Embed converts text into a vector embedding
@@ -17,6 +25,15 @@ type Embedder interface {
 	Close() error
 }

+// RoleAwareEmbedder optionally supports model-specific input formatting for
+// different embedding tasks such as indexing documents vs embedding queries.
+type RoleAwareEmbedder interface {
+	Embedder
+
+	EmbedWithRole(ctx context.Context, text string, role InputRole) ([]float32, error)
+	EmbedBatchWithRole(ctx context.Context, texts []string, role InputRole) ([][]float32, error)
+}
+
 // BatchProgress is a callback for reporting batch embedding progress.
 // It receives the batch index, total batches, chunk progress info, and optional retry information.
 // completedChunks and totalChunks track overall progress across all batches.
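A minimal sketch, assuming only the interfaces defined above, of how indexing code might consume the optional RoleAwareEmbedder: type-assert it and fall back to plain Embed when the provider has no role support. The helper name embedDocument is hypothetical and not part of this diff.

func embedDocument(ctx context.Context, emb Embedder, chunk string) ([]float32, error) {
	// Hypothetical helper: role-aware providers (such as the managed
	// llama.cpp embedder introduced in this change) can prepend
	// model-specific markers like "search_document: " before embedding.
	if ra, ok := emb.(RoleAwareEmbedder); ok {
		return ra.EmbedWithRole(ctx, chunk, RoleDocument)
	}
	// Providers without role support embed the text unchanged.
	return emb.Embed(ctx, chunk)
}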
diff --git a/embedder/embedder_test.go b/embedder/embedder_test.go
index 2a54705..c547889 100644
--- a/embedder/embedder_test.go
+++ b/embedder/embedder_test.go
@@ -1,9 +1,30 @@
 package embedder

 import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"runtime"
 	"testing"
+	"time"
+
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
 )

+func setEmbedderTestHome(t *testing.T, dir string) func() {
+	t.Helper()
+	if runtime.GOOS == "windows" {
+		original := os.Getenv("USERPROFILE")
+		_ = os.Setenv("USERPROFILE", dir)
+		return func() { _ = os.Setenv("USERPROFILE", original) }
+	}
+	original := os.Getenv("HOME")
+	_ = os.Setenv("HOME", dir)
+	return func() { _ = os.Setenv("HOME", original) }
+}
+
 // Test OllamaEmbedder options
 func TestNewOllamaEmbedder_Defaults(t *testing.T) {
 	e := NewOllamaEmbedder()
@@ -113,6 +134,81 @@ func TestNewLMStudioEmbedder_WithOptions(t *testing.T) {
 	}
 }

+func TestLlamaCPPEmbedder_AppliesRolePrefixes(t *testing.T) {
+	e := &LlamaCPPEmbedder{
+		model: "nomic-embed-text-v1.5-q8_0",
+		queryPrefix: "search_query: ",
+		docPrefix: "search_document: ",
+	}
+
+	if got := e.applyRolePrefix("hello", RoleQuery); got != "search_query: hello" {
+		t.Fatalf("query prefix = %q", got)
+	}
+	if got := e.applyRolePrefix("chunk", RoleDocument); got != "search_document: chunk" {
+		t.Fatalf("document prefix = %q", got)
+	}
+	if got := e.applyRolePrefix("search_query: hello", RoleQuery); got != "search_query: hello" {
+		t.Fatalf("query prefix duplicated: %q", got)
+	}
+}
+
+func TestNewLlamaCPPEmbedder_LoadsNomicModelMetadata(t *testing.T) {
+	tmpDir := t.TempDir()
+	modelPath := filepath.Join(tmpDir, "nomic.gguf")
+	e, err := NewLlamaCPPEmbedder(
+		WithLlamaCPPModel("nomic-embed-text-v1.5-q8_0"),
+		WithLlamaCPPModelPath(modelPath),
+	)
+	if err != nil {
+		t.Fatalf("NewLlamaCPPEmbedder failed: %v", err)
+	}
+	if e.queryPrefix != "search_query: " {
+		t.Fatalf("query prefix = %q", e.queryPrefix)
+	}
+	if e.docPrefix != "search_document: " {
+		t.Fatalf("doc prefix = %q", e.docPrefix)
+	}
+}
+
+func TestLlamaCPPEmbedder_EnsureRunningReusesHealthyEndpointWithoutPIDProbe(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setEmbedderTestHome(t, tmpDir)
+	defer cleanup()
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/health" {
+			w.WriteHeader(http.StatusOK)
+			return
+		}
+		http.NotFound(w, r)
+	}))
+	defer server.Close()
+
+	state := managedassets.RuntimeState{
+		Version: managedassets.DefaultRuntimeVersion,
+		Platform: runtime.GOOS,
+		Arch: runtime.GOARCH,
+		Binary: "/tmp/fake-llama-server",
+		Endpoint: server.URL,
+		PID: 999999,
+	}
+	if err := managedassets.SaveRuntimeState(state); err != nil {
+		t.Fatalf("SaveRuntimeState failed: %v", err)
+	}
+
+	e := &LlamaCPPEmbedder{
+		runtimePath: "/tmp/fake-llama-server",
+		endpoint: server.URL,
+		client: server.Client(),
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	if err := e.ensureRunning(ctx); err != nil {
+		t.Fatalf("ensureRunning failed: %v", err)
+	}
+}
+
 func TestLMStudioEmbedder_Dimensions(t *testing.T) {
 	tests := []struct {
 		name string
diff --git a/embedder/factory.go b/embedder/factory.go
index a200a43..645940d 100644
--- a/embedder/factory.go
+++ b/embedder/factory.go
@@ -33,6 +33,17 @@ func NewFromConfig(cfg *config.Config) (Embedder, error) {
 		}
 		return NewOpenAIEmbedder(opts...)
+ case "llamacpp": + opts := []LlamaCPPOption{ + WithLlamaCPPModel(cfg.Embedder.Model), + WithLlamaCPPModelPath(cfg.Embedder.ModelPath), + WithLlamaCPPEndpoint(cfg.Embedder.Endpoint), + } + if cfg.Embedder.Dimensions != nil { + opts = append(opts, WithLlamaCPPDimensions(*cfg.Embedder.Dimensions)) + } + return NewLlamaCPPEmbedder(opts...) + case "lmstudio": opts := []LMStudioOption{ WithLMStudioEndpoint(cfg.Embedder.Endpoint), diff --git a/embedder/factory_test.go b/embedder/factory_test.go index a0e18bc..02e37c8 100644 --- a/embedder/factory_test.go +++ b/embedder/factory_test.go @@ -1,6 +1,7 @@ package embedder import ( + "path/filepath" "testing" "github.com/yoanbernabeu/grepai/config" @@ -76,6 +77,33 @@ func TestNewFromConfig_LMStudio(t *testing.T) { } } +func TestNewFromConfig_LlamaCPP(t *testing.T) { + tmpDir := t.TempDir() + modelPath := filepath.Join(tmpDir, "embedding.gguf") + cfg := &config.Config{ + Embedder: config.EmbedderConfig{ + Provider: "llamacpp", + Model: config.DefaultLlamaCPPEmbeddingModel, + ModelPath: modelPath, + Endpoint: config.DefaultLlamaCPPEndpoint, + }, + } + + emb, err := NewFromConfig(cfg) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + defer emb.Close() + + llamaEmb, ok := emb.(*LlamaCPPEmbedder) + if !ok { + t.Errorf("expected *LlamaCPPEmbedder, got %T", emb) + } + if llamaEmb.modelPath != modelPath { + t.Errorf("expected model path %s, got %s", modelPath, llamaEmb.modelPath) + } +} + func TestNewFromConfig_Synthetic(t *testing.T) { t.Setenv("SYNTHETIC_API_KEY", "test-key") diff --git a/embedder/llamacpp.go b/embedder/llamacpp.go new file mode 100644 index 0000000..25adb4c --- /dev/null +++ b/embedder/llamacpp.go @@ -0,0 +1,367 @@ +package embedder + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" + + "github.com/yoanbernabeu/grepai/internal/managedassets" +) + +const ( + defaultLlamaCPPModel = managedassets.DefaultModelID +) + +type LlamaCPPEmbedder struct { + model string + modelPath string + endpoint string + dimensions int + runtimePath string + queryPrefix string + docPrefix string + client *http.Client +} + +type LlamaCPPOption func(*LlamaCPPEmbedder) + +type llamaCPPEmbedRequest struct { + Content string `json:"content,omitempty"` + Input string `json:"input,omitempty"` +} + +type llamaCPPEmbedResponse struct { + Embedding []float32 `json:"embedding"` + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` +} + +func WithLlamaCPPModel(model string) LlamaCPPOption { + return func(e *LlamaCPPEmbedder) { + e.model = model + } +} + +func WithLlamaCPPModelPath(path string) LlamaCPPOption { + return func(e *LlamaCPPEmbedder) { + e.modelPath = path + } +} + +func WithLlamaCPPEndpoint(endpoint string) LlamaCPPOption { + return func(e *LlamaCPPEmbedder) { + e.endpoint = endpoint + } +} + +func WithLlamaCPPDimensions(dimensions int) LlamaCPPOption { + return func(e *LlamaCPPEmbedder) { + e.dimensions = dimensions + } +} + +func WithLlamaCPPRuntimePath(path string) LlamaCPPOption { + return func(e *LlamaCPPEmbedder) { + e.runtimePath = path + } +} + +func NewLlamaCPPEmbedder(opts ...LlamaCPPOption) (*LlamaCPPEmbedder, error) { + e := &LlamaCPPEmbedder{ + model: defaultLlamaCPPModel, + endpoint: managedassets.DefaultSidecarEndpoint(), + dimensions: 384, + client: &http.Client{ + Timeout: 60 * time.Second, + }, + } + for _, opt := range opts { + opt(e) + } + if e.dimensions <= 0 { + e.dimensions = 768 + } + 
diff --git a/embedder/llamacpp.go b/embedder/llamacpp.go
new file mode 100644
index 0000000..25adb4c
--- /dev/null
+++ b/embedder/llamacpp.go
@@ -0,0 +1,367 @@
+package embedder
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"os"
+	"os/exec"
+	"runtime"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/yoanbernabeu/grepai/internal/managedassets"
+)
+
+const (
+	defaultLlamaCPPModel = managedassets.DefaultModelID
+)
+
+type LlamaCPPEmbedder struct {
+	model       string
+	modelPath   string
+	endpoint    string
+	dimensions  int
+	runtimePath string
+	queryPrefix string
+	docPrefix   string
+	client      *http.Client
+}
+
+type LlamaCPPOption func(*LlamaCPPEmbedder)
+
+// Both fields carry the same text: llama.cpp's native /embedding route reads
+// "content", while OpenAI-compatible server builds read "input".
+type llamaCPPEmbedRequest struct {
+	Content string `json:"content,omitempty"`
+	Input   string `json:"input,omitempty"`
+}
+
+type llamaCPPEmbedResponse struct {
+	Embedding []float32 `json:"embedding"`
+	Data      []struct {
+		Embedding []float32 `json:"embedding"`
+	} `json:"data"`
+}
+
+func WithLlamaCPPModel(model string) LlamaCPPOption {
+	return func(e *LlamaCPPEmbedder) {
+		e.model = model
+	}
+}
+
+func WithLlamaCPPModelPath(path string) LlamaCPPOption {
+	return func(e *LlamaCPPEmbedder) {
+		e.modelPath = path
+	}
+}
+
+func WithLlamaCPPEndpoint(endpoint string) LlamaCPPOption {
+	return func(e *LlamaCPPEmbedder) {
+		e.endpoint = endpoint
+	}
+}
+
+func WithLlamaCPPDimensions(dimensions int) LlamaCPPOption {
+	return func(e *LlamaCPPEmbedder) {
+		e.dimensions = dimensions
+	}
+}
+
+func WithLlamaCPPRuntimePath(path string) LlamaCPPOption {
+	return func(e *LlamaCPPEmbedder) {
+		e.runtimePath = path
+	}
+}
+
+func NewLlamaCPPEmbedder(opts ...LlamaCPPOption) (*LlamaCPPEmbedder, error) {
+	e := &LlamaCPPEmbedder{
+		model:      defaultLlamaCPPModel,
+		endpoint:   managedassets.DefaultSidecarEndpoint(),
+		dimensions: 384,
+		client: &http.Client{
+			Timeout: 60 * time.Second,
+		},
+	}
+	for _, opt := range opts {
+		opt(e)
+	}
+	if e.dimensions <= 0 {
+		e.dimensions = 768
+	}
+	modelPath, dims, err := managedassets.ResolveModelPath(e.model, e.modelPath)
+	if err != nil {
+		return nil, err
+	}
+	modelDef, err := managedassets.LookupModel(e.model)
+	if err == nil {
+		e.queryPrefix = modelDef.QueryPrefix
+		e.docPrefix = modelDef.DocPrefix
+	}
+	e.modelPath = modelPath
+	// Only replace the built-in default; an explicit WithLlamaCPPDimensions
+	// value other than 384 wins over model metadata.
+	if e.dimensions == 384 && dims > 0 {
+		e.dimensions = dims
+	}
+	if e.runtimePath == "" {
+		runtimeDef, err := managedassets.LookupCurrentRuntime()
+		if err != nil {
+			return nil, err
+		}
+		runtimePath, err := managedassets.ManagedRuntimeBinaryPath(runtimeDef)
+		if err != nil {
+			return nil, err
+		}
+		e.runtimePath = runtimePath
+	}
+	return e, nil
+}
+
+func (e *LlamaCPPEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
+	return e.EmbedWithRole(ctx, text, RoleGeneric)
+}
+
+func (e *LlamaCPPEmbedder) EmbedWithRole(ctx context.Context, text string, role InputRole) ([]float32, error) {
+	if err := e.ensureRunning(ctx); err != nil {
+		return nil, err
+	}
+	text = e.applyRolePrefix(text, role)
+	body, err := json.Marshal(llamaCPPEmbedRequest{
+		Content: text,
+		Input:   text,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal llama.cpp request: %w", err)
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(e.endpoint, "/")+"/embedding", bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create llama.cpp request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := e.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to send request to llama.cpp: %w", err)
+	}
+	defer resp.Body.Close()
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read llama.cpp response: %w", err)
+	}
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("llama.cpp returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
+	}
+	var result llamaCPPEmbedResponse
+	if err := json.Unmarshal(respBody, &result); err != nil {
+		return nil, fmt.Errorf("failed to decode llama.cpp response: %w", err)
+	}
+	switch {
+	case len(result.Embedding) > 0:
+		return result.Embedding, nil
+	case len(result.Data) > 0 && len(result.Data[0].Embedding) > 0:
+		return result.Data[0].Embedding, nil
+	default:
+		return nil, fmt.Errorf("llama.cpp returned empty embedding")
+	}
+}
+
+func (e *LlamaCPPEmbedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
+	return e.EmbedBatchWithRole(ctx, texts, RoleGeneric)
+}
+
+func (e *LlamaCPPEmbedder) EmbedBatchWithRole(ctx context.Context, texts []string, role InputRole) ([][]float32, error) {
+	results := make([][]float32, len(texts))
+	for i, text := range texts {
+		embedding, err := e.EmbedWithRole(ctx, text, role)
+		if err != nil {
+			return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
+		}
+		results[i] = embedding
+	}
+	return results, nil
+}
+
+func (e *LlamaCPPEmbedder) applyRolePrefix(text string, role InputRole) string {
+	switch role {
+	case RoleQuery:
+		if e.queryPrefix != "" && !strings.HasPrefix(text, e.queryPrefix) {
+			return e.queryPrefix + text
+		}
+	case RoleDocument:
+		if e.docPrefix != "" && !strings.HasPrefix(text, e.docPrefix) {
+			return e.docPrefix + text
+		}
+	}
+	return text
+}
+
+func (e *LlamaCPPEmbedder) Dimensions() int {
+	return e.dimensions
+}
+
+func (e *LlamaCPPEmbedder) Close() error {
+	return nil
+}
+
+func (e *LlamaCPPEmbedder) Ping(ctx context.Context) error {
+	if err := e.ensureRunning(ctx); err != nil {
+		return err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(e.endpoint, "/")+"/health", nil)
+	if err != nil {
+		return fmt.Errorf("failed to create llama.cpp health request: %w", err)
+	}
+	resp, err := e.client.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to reach llama.cpp at %s: %w", e.endpoint, err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("llama.cpp returned status %d", resp.StatusCode)
+	}
+	return nil
+}
+
+func (e *LlamaCPPEmbedder) ensureRunning(ctx context.Context) error {
+	// Probe with a bounded context: without its own deadline, the poll loop in
+	// waitForHealth would spin until the caller's context expired and the
+	// sidecar would never be started.
+	probeCtx, cancel := context.WithTimeout(ctx, time.Second)
+	healthy := waitForHealth(probeCtx, e.client, e.endpoint, 250*time.Millisecond)
+	cancel()
+	if healthy {
+		return nil
+	}
+
+	state, err := managedassets.LoadRuntimeState()
+	if err != nil {
+		return err
+	}
+	if state != nil && state.Binary == e.runtimePath && state.Endpoint == e.endpoint {
+		probeCtx, cancel := context.WithTimeout(ctx, time.Second)
+		healthy := waitForHealth(probeCtx, e.client, e.endpoint, 250*time.Millisecond)
+		cancel()
+		if healthy {
+			return nil
+		}
+	}
+	return e.startSidecar(ctx)
+}
+
+func (e *LlamaCPPEmbedder) startSidecar(ctx context.Context) error {
+	if err := managedassets.EnsureManagedDirs(); err != nil {
+		return err
+	}
+	runtimePath, _, err := managedassets.EnsureRuntime(ctx, nil)
+	if err != nil {
+		return err
+	}
+	e.runtimePath = runtimePath
+	u, err := net.ResolveTCPAddr("tcp", strings.TrimPrefix(strings.TrimPrefix(e.endpoint, "http://"), "https://"))
+	if err != nil {
+		return fmt.Errorf("invalid llama.cpp endpoint %s: %w", e.endpoint, err)
+	}
+	port := u.Port
+	host := u.IP.String()
+	// A nil IP stringifies to "<nil>", not "".
+	if host == "" || host == "<nil>" {
+		host = "127.0.0.1"
+	}
+	cmd := exec.CommandContext(ctx, e.runtimePath,
+		"--host", host,
+		"--port", strconv.Itoa(port),
+		"--model", e.modelPath,
+		"--embeddings",
+		"--batch-size", "4096",
+		"--ubatch-size", "4096",
+	)
+	logPath, err := managedassets.GetManagedRuntimeStatePath()
+	if err != nil {
+		return err
+	}
+	logPath = strings.TrimSuffix(logPath, ".json") + ".log"
+	logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+	if err != nil {
+		return fmt.Errorf("failed to open llama.cpp log file: %w", err)
+	}
+	cmd.Stdout = logFile
+	cmd.Stderr = logFile
+	if err := cmd.Start(); err != nil {
+		logFile.Close()
+		return fmt.Errorf("failed to start managed llama.cpp runtime: %w", err)
+	}
+	done := make(chan error, 1)
+	go func() {
+		done <- cmd.Wait()
+		_ = logFile.Close()
+	}()
+	state := managedassets.RuntimeState{
+		Version:  managedassets.DefaultRuntimeVersion,
+		Platform: runtime.GOOS,
+		Arch:     runtime.GOARCH,
+		Binary:   e.runtimePath,
+		Endpoint: e.endpoint,
+		PID:      cmd.Process.Pid,
+		Started:  time.Now().UTC(),
+	}
+	if err := managedassets.SaveRuntimeState(state); err != nil {
+		return err
+	}
+	healthCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
+	defer cancel()
+	if err := waitForRuntimeReady(healthCtx, e.client, e.endpoint, done); err != nil {
+		_ = managedassets.ClearRuntimeState()
+		return err
+	}
+	return nil
+}
+
+func waitForHealth(ctx context.Context, client *http.Client, endpoint string, interval time.Duration) bool {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	for {
+		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(endpoint, "/")+"/health", nil)
+		resp, err := client.Do(req)
+		if err == nil {
+			resp.Body.Close()
+			if resp.StatusCode == http.StatusOK {
+				return true
+			}
+		}
+		select {
+		case <-ctx.Done():
+			return false
+		case <-ticker.C:
+		}
+	}
+}
+
+func waitForRuntimeReady(ctx context.Context, client *http.Client, endpoint string, done <-chan error) error {
+	ticker := time.NewTicker(500 * time.Millisecond)
+	defer ticker.Stop()
+	for {
+		if checkHealth(client, endpoint) {
+			return nil
+		}
+		select {
+		case err := <-done:
+			if checkHealth(client, endpoint) {
+				return nil
+			}
+			if err != nil {
+				return fmt.Errorf("managed llama.cpp runtime exited before becoming ready: %w", err)
+			}
+			return fmt.Errorf("managed llama.cpp runtime exited before becoming ready")
+		case <-ctx.Done():
+			if checkHealth(client, endpoint) {
+				return nil
+			}
+			return fmt.Errorf("managed llama.cpp runtime did not become ready at %s", endpoint)
+		case <-ticker.C:
+		}
+	}
+}
+
+func checkHealth(client *http.Client, endpoint string) bool {
+	req, err := http.NewRequest(http.MethodGet, strings.TrimRight(endpoint, "/")+"/health", nil)
+	if err != nil {
+		return false
+	}
+	resp, err := client.Do(req)
+	if err != nil {
+		return false
+	}
+	resp.Body.Close()
+	return resp.StatusCode == http.StatusOK
+}
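Since the sidecar is a stock `llama-server`, the wire format the embedder speaks can be reproduced with a plain HTTP client. A minimal sketch, assuming a sidecar is already listening on the default port `12434` (from `managedassets.DefaultSidecarPort`) with a prefix-trained model loaded:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Same payload shape as llamaCPPEmbedRequest: both fields carry the text
	// so native ("content") and OpenAI-style ("input") servers accept it.
	text := "search_document: func ParseToken(s string) error { ... }"
	payload, _ := json.Marshal(map[string]string{
		"content": text,
		"input":   text,
	})
	resp, err := http.Post("http://127.0.0.1:12434/embedding", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err) // sidecar not running
	}
	defer resp.Body.Close()

	// Native llama-server builds reply with a top-level "embedding" array.
	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("dimensions:", len(out.Embedding))
}
```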
diff --git a/indexer/indexer.go b/indexer/indexer.go
index 475cef3..69e2771 100644
--- a/indexer/indexer.go
+++ b/indexer/indexer.go
@@ -278,6 +278,13 @@ func createStoreChunks(chunkInfos []ChunkInfo, embeddings [][]float32, now time.
 	return chunks, chunkIDs
 }
 
+// embedContents prefers the role-aware path so document-side prefixes are
+// applied when the active embedder supports them.
+func embedContents(ctx context.Context, emb embedder.Embedder, contents []string) ([][]float32, error) {
+	if roleAware, ok := emb.(embedder.RoleAwareEmbedder); ok {
+		return roleAware.EmbedBatchWithRole(ctx, contents, embedder.RoleDocument)
+	}
+	return emb.EmbedBatch(ctx, contents)
+}
+
 // saveFileData saves chunks and document metadata for a single file.
 func (idx *Indexer) saveFileData(ctx context.Context, fd fileChunkData, chunks []store.Chunk, chunkIDs []string) error {
 	if err := idx.store.SaveChunks(ctx, chunks); err != nil {
@@ -570,7 +577,7 @@ func (idx *Indexer) embedWithReChunking(ctx context.Context, chunks []ChunkInfo)
 		contents[i] = c.Content
 	}
 
-	vectors, err := idx.embedder.EmbedBatch(ctx, contents)
+	vectors, err := embedContents(ctx, idx.embedder, contents)
 	if err == nil {
 		// Success! Append all results
 		allVectors = append(allVectors, vectors...)
@@ -601,7 +608,7 @@ func (idx *Indexer) embedWithReChunking(ctx context.Context, chunks []ChunkInfo)
 	for i := 0; i < failedIndex; i++ {
 		beforeContents[i] = currentChunks[i].Content
 	}
-	beforeVectors, err := idx.embedder.EmbedBatch(ctx, beforeContents)
+	beforeVectors, err := embedContents(ctx, idx.embedder, beforeContents)
 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to embed chunks before failed index: %w", err)
 	}
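`embedContents` gates on an `embedder.RoleAwareEmbedder` interface that this patch does not itself declare. From the call sites here and in `search.go`, it presumably looks something like the following; this is a hypothetical reconstruction for reviewers, not the actual declaration:

```go
package embedder

import "context"

// Hypothetical reconstruction based on the call sites in this patch; the
// real declaration lives elsewhere in the embedder package.
type InputRole string

const (
	RoleGeneric  InputRole = "generic"
	RoleQuery    InputRole = "query"
	RoleDocument InputRole = "document"
)

// RoleAwareEmbedder is an optional extension of Embedder: callers type-assert
// for it and fall back to the plain Embed/EmbedBatch methods otherwise.
type RoleAwareEmbedder interface {
	Embedder
	EmbedWithRole(ctx context.Context, text string, role InputRole) ([]float32, error)
	EmbedBatchWithRole(ctx context.Context, texts []string, role InputRole) ([][]float32, error)
}
```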
diff --git a/indexer/indexer_test.go b/indexer/indexer_test.go
index 5b7535f..f3b05ab 100644
--- a/indexer/indexer_test.go
+++ b/indexer/indexer_test.go
@@ -194,6 +194,21 @@ func (m *mockEmbedder) Close() error {
 	return nil
 }
 
+type roleAwareMockEmbedder struct {
+	mockEmbedder
+	lastRole embedder.InputRole
+}
+
+func (m *roleAwareMockEmbedder) EmbedWithRole(ctx context.Context, text string, role embedder.InputRole) ([]float32, error) {
+	m.lastRole = role
+	return m.Embed(ctx, text)
+}
+
+func (m *roleAwareMockEmbedder) EmbedBatchWithRole(ctx context.Context, texts []string, role embedder.InputRole) ([][]float32, error) {
+	m.lastRole = role
+	return m.EmbedBatch(ctx, texts)
+}
+
 // TestIndexAllWithProgress_UnchangedFilesSkipped tests that files with matching ModTimes are skipped
 func TestIndexAllWithProgress_UnchangedFilesSkipped(t *testing.T) {
 	tmpDir := t.TempDir()
@@ -1165,3 +1180,14 @@ func TestEmbedWithReChunking_ReChunksOnError(t *testing.T) {
 		t.Errorf("vectors count %d != chunks count %d", len(vectors), len(finalChunks))
 	}
 }
+
+func TestEmbedContentsUsesDocumentRoleWhenSupported(t *testing.T) {
+	mockEmb := &roleAwareMockEmbedder{}
+	_, err := embedContents(context.Background(), mockEmb, []string{"hello"})
+	if err != nil {
+		t.Fatalf("embedContents failed: %v", err)
+	}
+	if mockEmb.lastRole != embedder.RoleDocument {
+		t.Fatalf("expected document role, got %s", mockEmb.lastRole)
+	}
+}
diff --git a/internal/managedassets/assets.go b/internal/managedassets/assets.go
new file mode 100644
index 0000000..0c5400f
--- /dev/null
+++ b/internal/managedassets/assets.go
@@ -0,0 +1,675 @@
+package managedassets
+
+import (
+	"archive/tar"
+	"archive/zip"
+	"compress/gzip"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"runtime"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/yoanbernabeu/grepai/config"
+)
+
+const (
+	DefaultModelID          = "bge-small-en-v1.5-q8_0"
+	DefaultRuntimeVersion   = "b3426"
+	DefaultSidecarPort      = 12434
+	runtimeStateFileName    = "llamacpp-runtime.json"
+	modelManifestFileName   = "models.json"
+	runtimeDownloadTimeout  = 10 * time.Minute
+	modelDownloadTimeout    = 30 * time.Minute
+	defaultEmbeddingDimSize = 768
+)
+
+type ModelDefinition struct {
+	ID          string `json:"id"`
+	Display     string `json:"display"`
+	SizeBytes   int64  `json:"size_bytes"`
+	FileName    string `json:"file_name"`
+	URL         string `json:"url"`
+	SHA256      string `json:"sha256,omitempty"`
+	Dimensions  int    `json:"dimensions"`
+	QueryPrefix string `json:"query_prefix,omitempty"`
+	DocPrefix   string `json:"doc_prefix,omitempty"`
+}
+
+type RuntimeDefinition struct {
+	Version  string `json:"version"`
+	Platform string `json:"platform"`
+	Arch     string `json:"arch"`
+	URL      string `json:"url"`
+	SHA256   string `json:"sha256,omitempty"`
+	Archive  string `json:"archive"`
+	Binary   string `json:"binary"`
+}
+
+type InstalledModel struct {
+	ID         string    `json:"id"`
+	FileName   string    `json:"file_name"`
+	Path       string    `json:"path"`
+	SourceURL  string    `json:"source_url"`
+	Installed  time.Time `json:"installed_at"`
+	SizeBytes  int64     `json:"size_bytes"`
+	Dimensions int       `json:"dimensions"`
+}
+
+type RuntimeState struct {
+	Version  string    `json:"version"`
+	Platform string    `json:"platform"`
+	Arch     string    `json:"arch"`
+	Binary   string    `json:"binary"`
+	Endpoint string    `json:"endpoint"`
+	PID      int       `json:"pid"`
+	Started  time.Time `json:"started_at"`
+}
+
+var defaultModels = map[string]ModelDefinition{
+	DefaultModelID: {
+		ID:         DefaultModelID,
+		Display:    "BGE Small English v1.5 Q8_0",
+		SizeBytes:  36685152,
+		FileName:   "bge-small-en-v1.5-q8_0.gguf",
+		URL:        "https://huggingface.co/ggml-org/bge-small-en-v1.5-Q8_0-GGUF/resolve/main/bge-small-en-v1.5-q8_0.gguf?download=1",
+		Dimensions: 384,
+	},
+	"nomic-embed-text-v1.5-q8_0": {
+		ID:          "nomic-embed-text-v1.5-q8_0",
+		Display:     "Nomic Embed Text v1.5 Q8_0",
+		SizeBytes:   153092096,
+		FileName:    "nomic-embed-text-v1.5.Q8_0.gguf",
+		URL:         "https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q8_0.gguf?download=1",
+		Dimensions:  768,
+		QueryPrefix: "search_query: ",
+		DocPrefix:   "search_document: ",
+	},
+	"nomic-embed-text-v1.5-q4_k_m": {
+		ID:          "nomic-embed-text-v1.5-q4_k_m",
+		Display:     "Nomic Embed Text v1.5 Q4_K_M",
+		SizeBytes:   88185242,
+		FileName:    "nomic-embed-text-v1.5.Q4_K_M.gguf",
+		URL:         "https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q4_K_M.gguf?download=1",
+		Dimensions:  768,
+		QueryPrefix: "search_query: ",
+		DocPrefix:   "search_document: ",
+	},
+}
+
+var runtimeDefinitions = map[string]RuntimeDefinition{
+	"darwin/arm64": {
+		Version:  DefaultRuntimeVersion,
+		Platform: "darwin",
+		Arch:     "arm64",
+		URL:      "https://github.com/ggml-org/llama.cpp/releases/download/b3426/llama-b3426-bin-macos-arm64.zip",
+		Archive:  "zip",
+		Binary:   "llama-server",
+	},
+	"darwin/amd64": {
+		Version:  DefaultRuntimeVersion,
+		Platform: "darwin",
+		Arch:     "amd64",
+		URL:      "https://github.com/ggml-org/llama.cpp/releases/download/b3426/llama-b3426-bin-macos-x64.zip",
+		Archive:  "zip",
+		Binary:   "llama-server",
+	},
+	"linux/amd64": {
+		Version:  DefaultRuntimeVersion,
+		Platform: "linux",
+		Arch:     "amd64",
+		URL:      "https://github.com/ggml-org/llama.cpp/releases/download/b3426/llama-b3426-bin-ubuntu-x64.zip",
+		Archive:  "zip",
+		Binary:   "llama-server",
+	},
+	"windows/amd64": {
+		Version:  DefaultRuntimeVersion,
+		Platform: "windows",
+		Arch:     "amd64",
+		URL:      "https://github.com/ggml-org/llama.cpp/releases/download/b3426/llama-b3426-bin-win-avx2-x64.zip",
+		Archive:  "zip",
+		Binary:   "llama-server.exe",
+	},
+}
+
+func GetManagedBinDir() (string, error) {
+	root, err := config.GetGlobalConfigDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(root, "bin"), nil
+}
+
+func GetManagedModelsDir() (string, error) {
+	root, err := config.GetGlobalConfigDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(root, "models"), nil
+}
+
+func GetManagedStateDir() (string, error) {
+	root, err := config.GetGlobalConfigDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(root, "state"), nil
+}
+
+func GetManagedRuntimeStatePath() (string, error) {
+	dir, err := GetManagedStateDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(dir, runtimeStateFileName), nil
+}
+
+func GetManagedModelManifestPath() (string, error) {
+	dir, err := GetManagedModelsDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(dir, modelManifestFileName), nil
+}
+
+func DefaultSidecarEndpoint() string {
+	return fmt.Sprintf("http://127.0.0.1:%d", DefaultSidecarPort)
+}
+
+func LookupModel(id string) (ModelDefinition, error) {
+	if id == "" {
+		id = DefaultModelID
+	}
+	def, ok := defaultModels[id]
+	if !ok {
+		return ModelDefinition{}, fmt.Errorf("unknown managed model: %s", id)
+	}
+	return def, nil
+}
+
+func ListAvailableModels() []ModelDefinition {
+	models := make([]ModelDefinition, 0, len(defaultModels))
+	for _, model := range defaultModels {
+		models = append(models, model)
+	}
+	slices.SortFunc(models, func(a, b ModelDefinition) int {
+		return strings.Compare(a.ID, b.ID)
+	})
+	return models
+}
+
+func LookupRuntime(goos, goarch string) (RuntimeDefinition, error) {
+	key := goos + "/" + goarch
+	def, ok := runtimeDefinitions[key]
+	if !ok {
+		return RuntimeDefinition{}, fmt.Errorf("managed llama.cpp runtime is not available for %s/%s", goos, goarch)
+	}
+	return def, nil
+}
+
+func LookupCurrentRuntime() (RuntimeDefinition, error) {
+	return LookupRuntime(runtime.GOOS, runtime.GOARCH)
+}
+
+func EnsureManagedDirs() error {
+	for _, fn := range []func() (string, error){GetManagedBinDir, GetManagedModelsDir, GetManagedStateDir} {
+		dir, err := fn()
+		if err != nil {
+			return err
+		}
+		if err := os.MkdirAll(dir, 0o755); err != nil {
+			return fmt.Errorf("failed to create managed assets directory %s: %w", dir, err)
+		}
+	}
+	return nil
+}
+
+func ManagedModelPath(def ModelDefinition) (string, error) {
+	dir, err := GetManagedModelsDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(dir, def.FileName), nil
+}
+
+func ManagedRuntimeBinaryPath(def RuntimeDefinition) (string, error) {
+	dir, err := GetManagedBinDir()
+	if err != nil {
+		return "", err
+	}
+	return filepath.Join(dir, def.Binary), nil
+}
+
+func LoadInstalledModels() ([]InstalledModel, error) {
+	manifestPath, err := GetManagedModelManifestPath()
+	if err != nil {
+		return nil, err
+	}
+	data, err := os.ReadFile(manifestPath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("failed to read model manifest: %w", err)
+	}
+	var models []InstalledModel
+	if err := json.Unmarshal(data, &models); err != nil {
+		return nil, fmt.Errorf("failed to parse model manifest: %w", err)
+	}
+	for i := range models {
+		if models[i].SizeBytes <= 0 {
+			if st, err := os.Stat(models[i].Path); err == nil {
+				models[i].SizeBytes = st.Size()
+			} else if def, ok := defaultModels[models[i].ID]; ok && def.SizeBytes > 0 {
+				models[i].SizeBytes = def.SizeBytes
+			}
+		}
+		if models[i].Dimensions <= 0 {
+			if def, ok := defaultModels[models[i].ID]; ok && def.Dimensions > 0 {
+				models[i].Dimensions = def.Dimensions
+			}
+		}
+	}
+	return models, nil
+}
+
+func SaveInstalledModels(models []InstalledModel) error {
+	if err := EnsureManagedDirs(); err != nil {
+		return err
+	}
+	manifestPath, err := GetManagedModelManifestPath()
+	if err != nil {
+		return err
+	}
+	data, err := json.MarshalIndent(models, "", "  ")
+	if err != nil {
+		return fmt.Errorf("failed to marshal model manifest: %w", err)
+	}
+	return os.WriteFile(manifestPath, data, 0o600)
+}
+
+func FindInstalledModel(id string) (*InstalledModel, error) {
+	models, err := LoadInstalledModels()
+	if err != nil {
+		return nil, err
+	}
+	for i := range models {
+		if models[i].ID == id {
+			return &models[i], nil
+		}
+	}
+	return nil, nil
+}
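The manifest handled by `LoadInstalledModels`/`SaveInstalledModels` is a plain JSON array of `InstalledModel` records. A sketch of its on-disk shape, printed rather than written to avoid clobbering a real install (all values illustrative; the path is hypothetical since the config root depends on `config.GetGlobalConfigDir`):

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/yoanbernabeu/grepai/internal/managedassets"
)

func main() {
	// Illustrative record; real entries are produced by InstallModel.
	manifest := []managedassets.InstalledModel{{
		ID:         managedassets.DefaultModelID,
		FileName:   "bge-small-en-v1.5-q8_0.gguf",
		Path:       "<config-dir>/models/bge-small-en-v1.5-q8_0.gguf", // hypothetical path
		SourceURL:  "https://huggingface.co/ggml-org/bge-small-en-v1.5-Q8_0-GGUF/resolve/main/bge-small-en-v1.5-q8_0.gguf?download=1",
		Installed:  time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
		SizeBytes:  36685152,
		Dimensions: 384,
	}}
	out, _ := json.MarshalIndent(manifest, "", "  ")
	fmt.Println(string(out)) // what models.json looks like on disk
}
```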
+
+func InstallModel(ctx context.Context, id string, progress func(downloaded, total int64)) (*InstalledModel, error) {
+	def, err := LookupModel(id)
+	if err != nil {
+		return nil, err
+	}
+	if err := EnsureManagedDirs(); err != nil {
+		return nil, err
+	}
+	modelPath, err := ManagedModelPath(def)
+	if err != nil {
+		return nil, err
+	}
+	ctx, cancel := context.WithTimeout(ctx, modelDownloadTimeout)
+	defer cancel()
+	if err := downloadFile(ctx, def.URL, modelPath, def.SHA256, progress); err != nil {
+		return nil, err
+	}
+	model := InstalledModel{
+		ID:         def.ID,
+		FileName:   def.FileName,
+		Path:       modelPath,
+		SourceURL:  def.URL,
+		Installed:  time.Now().UTC(),
+		SizeBytes:  def.SizeBytes,
+		Dimensions: def.Dimensions,
+	}
+	models, err := LoadInstalledModels()
+	if err != nil {
+		return nil, err
+	}
+	replaced := false
+	for i := range models {
+		if models[i].ID == model.ID {
+			models[i] = model
+			replaced = true
+			break
+		}
+	}
+	if !replaced {
+		models = append(models, model)
+	}
+	if err := SaveInstalledModels(models); err != nil {
+		return nil, err
+	}
+	return &model, nil
+}
+
+func RemoveInstalledModel(id string) error {
+	models, err := LoadInstalledModels()
+	if err != nil {
+		return err
+	}
+	filtered := models[:0]
+	removed := false
+	for _, m := range models {
+		if m.ID == id {
+			removed = true
+			if m.Path != "" {
+				_ = os.Remove(m.Path)
+			}
+			continue
+		}
+		filtered = append(filtered, m)
+	}
+	if !removed {
+		return fmt.Errorf("managed model %q is not installed", id)
+	}
+	return SaveInstalledModels(filtered)
+}
+
+func EnsureRuntime(ctx context.Context, progress func(downloaded, total int64)) (string, RuntimeDefinition, error) {
+	def, err := LookupCurrentRuntime()
+	if err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	if err := EnsureManagedDirs(); err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	binPath, err := ManagedRuntimeBinaryPath(def)
+	if err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	if st, err := os.Stat(binPath); err == nil && st.Mode().IsRegular() {
+		return binPath, def, nil
+	}
+	ctx, cancel := context.WithTimeout(ctx, runtimeDownloadTimeout)
+	defer cancel()
+	tmpDir, err := os.MkdirTemp("", "grepai-llamacpp-runtime-*")
+	if err != nil {
+		return "", RuntimeDefinition{}, fmt.Errorf("failed to create temp dir: %w", err)
+	}
+	defer os.RemoveAll(tmpDir)
+	archivePath := filepath.Join(tmpDir, filepath.Base(def.URL))
+	if err := downloadFile(ctx, def.URL, archivePath, def.SHA256, progress); err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	if err := extractArchive(archivePath, tmpDir, def.Archive); err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	extracted, err := findFile(tmpDir, def.Binary)
+	if err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	if err := copyExecutable(extracted, binPath); err != nil {
+		return "", RuntimeDefinition{}, err
+	}
+	return binPath, def, nil
+}
+
+func ResolveModelPath(id, override string) (string, int, error) {
+	if strings.TrimSpace(override) != "" {
+		return override, defaultEmbeddingDimSize, nil
+	}
+	if id == "" {
+		id = DefaultModelID
+	}
+	installed, err := FindInstalledModel(id)
+	if err != nil {
+		return "", 0, err
+	}
+	if installed == nil {
+		return "", 0, fmt.Errorf("managed model %q is not installed; run 'grepai model install %s'", id, id)
+	}
+	return installed.Path, installed.Dimensions, nil
+}
+
+func LoadRuntimeState() (*RuntimeState, error) {
+	path, err := GetManagedRuntimeStatePath()
+	if err != nil {
+		return nil, err
+	}
+	data, err := os.ReadFile(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("failed to read runtime state: %w", err)
+	}
+	var state RuntimeState
+	if err := json.Unmarshal(data, &state); err != nil {
+		return nil, fmt.Errorf("failed to parse runtime state: %w", err)
+	}
+	return &state, nil
+}
+
+func SaveRuntimeState(state RuntimeState) error {
+	if err := EnsureManagedDirs(); err != nil {
+		return err
+	}
+	path, err := GetManagedRuntimeStatePath()
+	if err != nil {
+		return err
+	}
+	data, err := json.MarshalIndent(state, "", "  ")
+	if err != nil {
+		return fmt.Errorf("failed to marshal runtime state: %w", err)
+	}
+	return os.WriteFile(path, data, 0o600)
+}
+
+func ClearRuntimeState() error {
+	path, err := GetManagedRuntimeStatePath()
+	if err != nil {
+		return err
+	}
+	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("failed to remove runtime state: %w", err)
+	}
+	return nil
+}
+
+func downloadFile(ctx context.Context, url, dest, checksum string, progress func(downloaded, total int64)) error {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return fmt.Errorf("failed to create download request: %w", err)
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to download %s: %w", url, err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("download failed for %s: status %d", url, resp.StatusCode)
+	}
+	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
+		return fmt.Errorf("failed to create destination directory: %w", err)
+	}
+	// Download to a sibling .tmp file and rename on success so a partial
+	// download never masquerades as a complete asset.
+	tmp := dest + ".tmp"
+	f, err := os.Create(tmp)
+	if err != nil {
+		return fmt.Errorf("failed to create temporary file: %w", err)
+	}
+	defer f.Close()
+
+	var r io.Reader = resp.Body
+	var written int64
+	buf := make([]byte, 32*1024)
+	hash := sha256.New()
+	total := resp.ContentLength
+
+	for {
+		n, readErr := r.Read(buf)
+		if n > 0 {
+			chunk := buf[:n]
+			if _, err := f.Write(chunk); err != nil {
+				return fmt.Errorf("failed to write download: %w", err)
+			}
+			if _, err := hash.Write(chunk); err != nil {
+				return fmt.Errorf("failed to hash download: %w", err)
+			}
+			written += int64(n)
+			if progress != nil {
+				progress(written, total)
+			}
+		}
+		if readErr == io.EOF {
+			break
+		}
+		if readErr != nil {
+			return fmt.Errorf("failed while downloading %s: %w", url, readErr)
+		}
+	}
+
+	if checksum != "" {
+		got := hex.EncodeToString(hash.Sum(nil))
+		if !strings.EqualFold(got, checksum) {
+			return fmt.Errorf("checksum mismatch for %s", filepath.Base(dest))
+		}
+	}
+	if err := f.Close(); err != nil {
+		return fmt.Errorf("failed to finalize download: %w", err)
+	}
+	return os.Rename(tmp, dest)
+}
+
+func extractArchive(archivePath, destDir, kind string) error {
+	switch kind {
+	case "zip":
+		return extractZip(archivePath, destDir)
+	case "tar.gz":
+		return extractTarGz(archivePath, destDir)
+	default:
+		return fmt.Errorf("unsupported archive type: %s", kind)
+	}
+}
+
+func extractZip(archivePath, destDir string) error {
+	r, err := zip.OpenReader(archivePath)
+	if err != nil {
+		return fmt.Errorf("failed to open zip archive: %w", err)
+	}
+	defer r.Close()
+	for _, f := range r.File {
+		target := filepath.Join(destDir, f.Name)
+		// Guard against zip-slip: reject entries that escape destDir.
+		if !strings.HasPrefix(target, filepath.Clean(destDir)+string(os.PathSeparator)) {
+			return fmt.Errorf("archive entry %s escapes extraction directory", f.Name)
+		}
+		if f.FileInfo().IsDir() {
+			if err := os.MkdirAll(target, 0o755); err != nil {
+				return err
+			}
+			continue
+		}
+		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
+			return err
+		}
+		rc, err := f.Open()
+		if err != nil {
+			return err
+		}
+		out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, f.Mode())
+		if err != nil {
+			rc.Close()
+			return err
+		}
+		if _, err := io.Copy(out, rc); err != nil {
+			out.Close()
+			rc.Close()
+			return err
+		}
+		out.Close()
+		rc.Close()
+	}
+	return nil
+}
+
+func extractTarGz(archivePath, destDir string) error {
+	f, err := os.Open(archivePath)
+	if err != nil {
+		return fmt.Errorf("failed to open tar.gz archive: %w", err)
+	}
+	defer f.Close()
+	gz, err := gzip.NewReader(f)
+	if err != nil {
+		return fmt.Errorf("failed to create gzip reader: %w", err)
+	}
+	defer gz.Close()
+	return untar(gz, destDir)
+}
+
+func untar(r io.Reader, destDir string) error {
+	tr := tar.NewReader(r)
+	for {
+		hdr, err := tr.Next()
+		if err == io.EOF {
+			return nil
+		}
+		if err != nil {
+			return err
+		}
+		target := filepath.Join(destDir, hdr.Name)
+		// Same path-escape guard as extractZip.
+		if !strings.HasPrefix(target, filepath.Clean(destDir)+string(os.PathSeparator)) {
+			return fmt.Errorf("archive entry %s escapes extraction directory", hdr.Name)
+		}
+		switch hdr.Typeflag {
+		case tar.TypeDir:
+			if err := os.MkdirAll(target, 0o755); err != nil {
+				return err
+			}
+		case tar.TypeReg:
+			if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
+				return err
+			}
+			out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode))
+			if err != nil {
+				return err
+			}
+			if _, err := io.Copy(out, tr); err != nil {
+				out.Close()
+				return err
+			}
+			out.Close()
+		}
+	}
+}
+
+func findFile(root, fileName string) (string, error) {
+	var found string
+	err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if info.IsDir() {
+			return nil
+		}
+		if filepath.Base(path) == fileName {
+			found = path
+			// io.EOF is used as a sentinel to stop the walk early.
+			return io.EOF
+		}
+		return nil
+	})
+	if err != nil && err != io.EOF {
+		return "", err
+	}
+	if found == "" {
+		return "", fmt.Errorf("required file %s not found in extracted archive", fileName)
+	}
+	return found, nil
+}
+
+func copyExecutable(src, dst string) error {
+	data, err := os.ReadFile(src)
+	if err != nil {
+		return fmt.Errorf("failed to read executable: %w", err)
+	}
+	if err := os.WriteFile(dst, data, 0o755); err != nil {
+		return fmt.Errorf("failed to install executable: %w", err)
+	}
+	return nil
+}
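Taken together, `InstallModel` and `EnsureRuntime` are the provisioning path that `grepai model install` presumably drives. A minimal sketch of calling them directly, using only the signatures above:

```go
package main

import (
	"context"
	"fmt"

	"github.com/yoanbernabeu/grepai/internal/managedassets"
)

func main() {
	ctx := context.Background()

	// Download the default GGUF model and record it in models.json.
	model, err := managedassets.InstallModel(ctx, managedassets.DefaultModelID,
		func(downloaded, total int64) {
			fmt.Printf("\rmodel: %d/%d bytes", downloaded, total)
		})
	if err != nil {
		panic(err)
	}
	fmt.Printf("\ninstalled %s (%d dims) at %s\n", model.ID, model.Dimensions, model.Path)

	// Fetch and unpack the llama-server binary for this OS/arch if missing.
	binPath, def, err := managedassets.EnsureRuntime(ctx, nil)
	if err != nil {
		panic(err) // e.g. unsupported platform
	}
	fmt.Printf("runtime %s ready at %s\n", def.Version, binPath)
}
```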
diff --git a/internal/managedassets/assets_test.go b/internal/managedassets/assets_test.go
new file mode 100644
index 0000000..ef165e9
--- /dev/null
+++ b/internal/managedassets/assets_test.go
@@ -0,0 +1,125 @@
+package managedassets
+
+import (
+	"os"
+	"path/filepath"
+	"runtime"
+	"testing"
+)
+
+func setTestHomeDir(t *testing.T, dir string) func() {
+	t.Helper()
+	if runtime.GOOS == "windows" {
+		original := os.Getenv("USERPROFILE")
+		os.Setenv("USERPROFILE", dir)
+		return func() { os.Setenv("USERPROFILE", original) }
+	}
+	original := os.Getenv("HOME")
+	os.Setenv("HOME", dir)
+	return func() { os.Setenv("HOME", original) }
+}
+
+func TestManagedPaths(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setTestHomeDir(t, tmpDir)
+	defer cleanup()
+
+	binDir, err := GetManagedBinDir()
+	if err != nil {
+		t.Fatalf("GetManagedBinDir failed: %v", err)
+	}
+	modelDir, err := GetManagedModelsDir()
+	if err != nil {
+		t.Fatalf("GetManagedModelsDir failed: %v", err)
+	}
+	if filepath.Base(binDir) != "bin" {
+		t.Fatalf("expected bin dir, got %s", binDir)
+	}
+	if filepath.Base(modelDir) != "models" {
+		t.Fatalf("expected models dir, got %s", modelDir)
+	}
+}
+
+func TestSaveAndLoadInstalledModels(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setTestHomeDir(t, tmpDir)
+	defer cleanup()
+
+	models := []InstalledModel{{
+		ID:         DefaultModelID,
+		FileName:   "test.gguf",
+		Path:       filepath.Join(tmpDir, "test.gguf"),
+		SourceURL:  "https://example.com/test.gguf",
+		Dimensions: 768,
+	}}
+	if err := SaveInstalledModels(models); err != nil {
+		t.Fatalf("SaveInstalledModels failed: %v", err)
+	}
+	loaded, err := LoadInstalledModels()
+	if err != nil {
+		t.Fatalf("LoadInstalledModels failed: %v", err)
+	}
+	if len(loaded) != 1 || loaded[0].ID != DefaultModelID {
+		t.Fatalf("unexpected loaded models: %+v", loaded)
+	}
+}
+
+func TestLookupCurrentRuntime(t *testing.T) {
+	if _, err := LookupCurrentRuntime(); err != nil {
+		t.Fatalf("LookupCurrentRuntime failed for %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+	}
+}
+
+func TestLookupRuntime_KnownCrossPlatformTargets(t *testing.T) {
+	targets := [][2]string{
+		{"darwin", "arm64"},
+		{"darwin", "amd64"},
+		{"linux", "amd64"},
+		{"windows", "amd64"},
+	}
+
+	for _, target := range targets {
+		def, err := LookupRuntime(target[0], target[1])
+		if err != nil {
+			t.Fatalf("LookupRuntime(%s, %s) failed: %v", target[0], target[1], err)
+		}
+		if def.URL == "" || def.Binary == "" {
+			t.Fatalf("incomplete runtime definition for %s/%s: %+v", target[0], target[1], def)
+		}
+	}
+}
+
+func TestRuntimeStateRoundTrip(t *testing.T) {
+	tmpDir := t.TempDir()
+	cleanup := setTestHomeDir(t, tmpDir)
+	defer cleanup()
+
+	state := RuntimeState{
+		Version:  DefaultRuntimeVersion,
+		Platform: "darwin",
+		Arch:     "arm64",
+		Binary:   "/tmp/llama-server",
+		Endpoint: DefaultSidecarEndpoint(),
+		PID:      12345,
+	}
+	if err := SaveRuntimeState(state); err != nil {
+		t.Fatalf("SaveRuntimeState failed: %v", err)
+	}
+	loaded, err := LoadRuntimeState()
+	if err != nil {
+		t.Fatalf("LoadRuntimeState failed: %v", err)
+	}
+	if loaded == nil || loaded.PID != state.PID || loaded.Endpoint != state.Endpoint {
+		t.Fatalf("unexpected runtime state: %+v", loaded)
+	}
+	if err := ClearRuntimeState(); err != nil {
+		t.Fatalf("ClearRuntimeState failed: %v", err)
+	}
+	loaded, err = LoadRuntimeState()
+	if err != nil {
+		t.Fatalf("LoadRuntimeState after clear failed: %v", err)
+	}
+	if loaded != nil {
+		t.Fatalf("expected nil runtime state after clear, got %+v", loaded)
+	}
+}
+ t.Fatalf("unexpected loaded models: %+v", loaded) + } +} + +func TestLookupCurrentRuntime(t *testing.T) { + if _, err := LookupCurrentRuntime(); err != nil { + t.Fatalf("LookupCurrentRuntime failed for %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } +} + +func TestLookupRuntime_KnownCrossPlatformTargets(t *testing.T) { + targets := [][2]string{ + {"darwin", "arm64"}, + {"darwin", "amd64"}, + {"linux", "amd64"}, + {"windows", "amd64"}, + } + + for _, target := range targets { + def, err := LookupRuntime(target[0], target[1]) + if err != nil { + t.Fatalf("LookupRuntime(%s, %s) failed: %v", target[0], target[1], err) + } + if def.URL == "" || def.Binary == "" { + t.Fatalf("incomplete runtime definition for %s/%s: %+v", target[0], target[1], def) + } + } +} + +func TestRuntimeStateRoundTrip(t *testing.T) { + tmpDir := t.TempDir() + cleanup := setTestHomeDir(t, tmpDir) + defer cleanup() + + state := RuntimeState{ + Version: DefaultRuntimeVersion, + Platform: "darwin", + Arch: "arm64", + Binary: "/tmp/llama-server", + Endpoint: DefaultSidecarEndpoint(), + PID: 12345, + } + if err := SaveRuntimeState(state); err != nil { + t.Fatalf("SaveRuntimeState failed: %v", err) + } + loaded, err := LoadRuntimeState() + if err != nil { + t.Fatalf("LoadRuntimeState failed: %v", err) + } + if loaded == nil || loaded.PID != state.PID || loaded.Endpoint != state.Endpoint { + t.Fatalf("unexpected runtime state: %+v", loaded) + } + if err := ClearRuntimeState(); err != nil { + t.Fatalf("ClearRuntimeState failed: %v", err) + } + loaded, err = LoadRuntimeState() + if err != nil { + t.Fatalf("LoadRuntimeState after clear failed: %v", err) + } + if loaded != nil { + t.Fatalf("expected nil runtime state after clear, got %+v", loaded) + } +} diff --git a/search/search.go b/search/search.go index 5776e78..9f85ba5 100644 --- a/search/search.go +++ b/search/search.go @@ -26,7 +26,7 @@ func NewSearcher(st store.VectorStore, emb embedder.Embedder, searchCfg config.S func (s *Searcher) Search(ctx context.Context, query string, limit int, pathPrefix string) ([]store.SearchResult, error) { // Embed the query - queryVector, err := s.embedder.Embed(ctx, query) + queryVector, err := embedQuery(ctx, s.embedder, query) if err != nil { return nil, err } @@ -59,6 +59,13 @@ func (s *Searcher) Search(ctx context.Context, query string, limit int, pathPref return results, nil } +func embedQuery(ctx context.Context, emb embedder.Embedder, query string) ([]float32, error) { + if roleAware, ok := emb.(embedder.RoleAwareEmbedder); ok { + return roleAware.EmbedWithRole(ctx, query, embedder.RoleQuery) + } + return emb.Embed(ctx, query) +} + // hybridSearch combines vector search and text search using RRF. 
diff --git a/search/search_test.go b/search/search_test.go
new file mode 100644
index 0000000..53e19c8
--- /dev/null
+++ b/search/search_test.go
@@ -0,0 +1,90 @@
+package search
+
+import (
+	"context"
+	"testing"
+
+	"github.com/yoanbernabeu/grepai/config"
+	"github.com/yoanbernabeu/grepai/embedder"
+	"github.com/yoanbernabeu/grepai/store"
+)
+
+type roleAwareTestEmbedder struct {
+	lastText string
+	lastRole embedder.InputRole
+}
+
+func (e *roleAwareTestEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
+	e.lastText = text
+	e.lastRole = embedder.RoleGeneric
+	return []float32{1, 2, 3}, nil
+}
+
+func (e *roleAwareTestEmbedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
+	out := make([][]float32, len(texts))
+	for i := range texts {
+		out[i] = []float32{1, 2, 3}
+	}
+	return out, nil
+}
+
+func (e *roleAwareTestEmbedder) EmbedWithRole(ctx context.Context, text string, role embedder.InputRole) ([]float32, error) {
+	e.lastText = text
+	e.lastRole = role
+	return []float32{1, 2, 3}, nil
+}
+
+func (e *roleAwareTestEmbedder) EmbedBatchWithRole(ctx context.Context, texts []string, role embedder.InputRole) ([][]float32, error) {
+	e.lastRole = role
+	out := make([][]float32, len(texts))
+	for i := range texts {
+		out[i] = []float32{1, 2, 3}
+	}
+	return out, nil
+}
+
+func (e *roleAwareTestEmbedder) Dimensions() int { return 3 }
+func (e *roleAwareTestEmbedder) Close() error    { return nil }
+
+type searchStoreStub struct{}
+
+func (s *searchStoreStub) SaveChunks(context.Context, []store.Chunk) error { return nil }
+func (s *searchStoreStub) DeleteByFile(context.Context, string) error      { return nil }
+func (s *searchStoreStub) Search(context.Context, []float32, int, store.SearchOptions) ([]store.SearchResult, error) {
+	return nil, nil
+}
+func (s *searchStoreStub) GetDocument(context.Context, string) (*store.Document, error) {
+	return nil, nil
+}
+func (s *searchStoreStub) SaveDocument(context.Context, store.Document) error { return nil }
+func (s *searchStoreStub) DeleteDocument(context.Context, string) error       { return nil }
+func (s *searchStoreStub) ListDocuments(context.Context) ([]string, error)    { return nil, nil }
+func (s *searchStoreStub) Load(context.Context) error                         { return nil }
+func (s *searchStoreStub) Persist(context.Context) error                      { return nil }
+func (s *searchStoreStub) Close() error                                       { return nil }
+func (s *searchStoreStub) GetStats(context.Context) (*store.IndexStats, error) {
+	return &store.IndexStats{}, nil
+}
+func (s *searchStoreStub) ListFilesWithStats(context.Context) ([]store.FileStats, error) {
+	return nil, nil
+}
+func (s *searchStoreStub) GetChunksForFile(context.Context, string) ([]store.Chunk, error) {
+	return nil, nil
+}
+func (s *searchStoreStub) GetAllChunks(context.Context) ([]store.Chunk, error) { return nil, nil }
+
+func TestSearchUsesQueryRoleWhenSupported(t *testing.T) {
+	emb := &roleAwareTestEmbedder{}
+	searcher := NewSearcher(&searchStoreStub{}, emb, config.SearchConfig{})
+
+	_, err := searcher.Search(context.Background(), "llama", 10, "")
+	if err != nil {
+		t.Fatalf("Search failed: %v", err)
+	}
+	if emb.lastRole != embedder.RoleQuery {
+		t.Fatalf("expected query role, got %s", emb.lastRole)
+	}
+	if emb.lastText != "llama" {
+		t.Fatalf("expected raw query text, got %q", emb.lastText)
+	}
+}