diff --git a/CHANGELOG.md b/CHANGELOG.md index 22043b8..dd07662 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Voyage AI Provider**: Add Voyage AI as a new cloud embedding provider (`voyageai`) with code-optimized models like `voyage-code-3` (1024 dims), batch embedding, adaptive rate limiting, and full integration across CLI, MCP server, and workspaces - **`.grepaiignore` Support**: New `.grepaiignore` file allows overriding `.gitignore` rules for grepai indexing. Supports negation patterns (`!`) to re-include files excluded by `.gitignore`, with directory-level precedence for nested files (#107) ## [0.34.0] - 2026-02-24 diff --git a/README.md b/README.md index dd5004f..b6bcae3 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ curl -sSL https://raw.githubusercontent.com/yoanbernabeu/grepai/main/install.sh irm https://raw.githubusercontent.com/yoanbernabeu/grepai/main/install.ps1 | iex ``` -Requires an embedding provider — [Ollama](https://ollama.ai) (default), [LM Studio](https://lmstudio.ai), or OpenAI. +Requires an embedding provider — [Ollama](https://ollama.ai) (default), [LM Studio](https://lmstudio.ai), OpenAI, or [Voyage AI](https://voyageai.com). 
**Ollama (recommended):** ```bash diff --git a/cli/init.go b/cli/init.go index d157d45..ca9ee02 100644 --- a/cli/init.go +++ b/cli/init.go @@ -32,14 +32,14 @@ var initCmd = &cobra.Command{ This command will: - Create .grepai/config.yaml with default settings -- Prompt for embedding provider (Ollama or OpenAI) +- Prompt for embedding provider (Ollama, LM Studio, OpenAI, or Voyage AI) - Prompt for storage backend (GOB file or PostgreSQL) - Add .grepai/ to .gitignore if present`, RunE: runInit, } func init() { - initCmd.Flags().StringVarP(&initProvider, "provider", "p", "", "Embedding provider (ollama, lmstudio, openai, synthetic, or openrouter)") + initCmd.Flags().StringVarP(&initProvider, "provider", "p", "", "Embedding provider (ollama, lmstudio, openai, voyageai, synthetic, or openrouter)") initCmd.Flags().StringVarP(&initModel, "model", "m", "", "Embedding model (for openrouter: text-embedding-3-small, text-embedding-3-large, qwen3-embedding-8b)") initCmd.Flags().StringVarP(&initBackend, "backend", "b", "", "Storage backend (gob, postgres, or qdrant)") initCmd.Flags().BoolVar(&initNonInteractive, "yes", false, "Use defaults without prompting") @@ -123,6 +123,7 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Println(" 3) openai (cloud, requires API key)") fmt.Println(" 4) synthetic (cloud, free embedding API)") fmt.Println(" 5) openrouter (cloud, multi-provider gateway)") + fmt.Println(" 6) voyageai (cloud, optimized for code, requires API key)") fmt.Print("Choice [1]: ") input, _ := reader.ReadString('\n') @@ -145,7 +146,7 @@ func runInit(cmd *cobra.Command, args []string) error { cfg.Embedder.Provider = "openai" cfg.Embedder.Model = "text-embedding-3-small" cfg.Embedder.Endpoint = "https://api.openai.com/v1" - // OpenAI: leave Dimensions nil to use model's native dimensions + cfg.Embedder.Dimensions = nil // use model's native dimensions case "4", "synthetic": cfg.Embedder.Provider = "synthetic" cfg.Embedder.Model = 
"hf:nomic-ai/nomic-embed-text-v1.5" @@ -155,7 +156,7 @@ func runInit(cmd *cobra.Command, args []string) error { case "5", "openrouter": cfg.Embedder.Provider = "openrouter" cfg.Embedder.Endpoint = "https://openrouter.ai/api/v1" - // OpenRouter: leave Dimensions nil to use model's native dimensions + cfg.Embedder.Dimensions = nil // use model's native dimensions // Model selection for OpenRouter fmt.Println("\nSelect OpenRouter embedding model:") @@ -175,6 +176,11 @@ func runInit(cmd *cobra.Command, args []string) error { default: cfg.Embedder.Model = "openai/text-embedding-3-small" } + case "6", "voyageai": + cfg.Embedder.Provider = "voyageai" + cfg.Embedder.Model = "voyage-code-3" + cfg.Embedder.Endpoint = "https://api.voyageai.com/v1" + cfg.Embedder.Dimensions = nil // use model's native dimensions (1024) default: cfg.Embedder.Provider = "ollama" fmt.Print("Ollama endpoint [http://localhost:11434]: ") @@ -196,7 +202,11 @@ func runInit(cmd *cobra.Command, args []string) error { case "openai": cfg.Embedder.Model = "text-embedding-3-small" cfg.Embedder.Endpoint = "https://api.openai.com/v1" - // OpenAI: leave Dimensions nil to use model's native dimensions + cfg.Embedder.Dimensions = nil // use model's native dimensions + case "voyageai": + cfg.Embedder.Model = "voyage-code-3" + cfg.Embedder.Endpoint = "https://api.voyageai.com/v1" + cfg.Embedder.Dimensions = nil // use model's native dimensions (1024) case "synthetic": cfg.Embedder.Model = "hf:nomic-ai/nomic-embed-text-v1.5" cfg.Embedder.Endpoint = "https://api.synthetic.new/openai/v1" @@ -205,7 +215,7 @@ func runInit(cmd *cobra.Command, args []string) error { case "openrouter": cfg.Embedder.Model = "openai/text-embedding-3-small" cfg.Embedder.Endpoint = "https://openrouter.ai/api/v1" - // OpenRouter: leave Dimensions nil to use model's native dimensions + cfg.Embedder.Dimensions = nil // use model's native dimensions } } @@ -339,6 +349,8 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Printf(" 
Endpoint: %s\n", cfg.Embedder.Endpoint) case "openai": fmt.Println("\nMake sure OPENAI_API_KEY is set in your environment.") + case "voyageai": + fmt.Println("\nMake sure VOYAGE_API_KEY is set in your environment.") case "synthetic": fmt.Println("\nMake sure SYNTHETIC_API_KEY or OPENAI_API_KEY is set in your environment.") fmt.Println(" Get your free API key at: https://api.synthetic.new") diff --git a/cli/watch.go b/cli/watch.go index da3262c..ceb4d0a 100644 --- a/cli/watch.go +++ b/cli/watch.go @@ -424,6 +424,7 @@ func initializeEmbedder(ctx context.Context, cfg *config.Config) (embedder.Embed return nil, fmt.Errorf("cannot connect to Ollama: %w\nMake sure Ollama is running and has the %s model", err, cfg.Embedder.Model) } } + case "lmstudio": if p, ok := emb.(pinger); ok { if err := p.Ping(ctx); err != nil { diff --git a/config/config.go b/config/config.go index 22f3dc4..0d6c6ad 100644 --- a/config/config.go +++ b/config/config.go @@ -24,10 +24,12 @@ const ( DefaultOpenAIEmbeddingModel = "text-embedding-3-small" DefaultSyntheticEmbeddingModel = "hf:nomic-ai/nomic-embed-text-v1.5" DefaultOpenRouterEmbeddingModel = "openai/text-embedding-3-small" + DefaultVoyageAIEmbeddingModel = "voyage-code-3" DefaultOllamaEndpoint = "http://localhost:11434" DefaultLMStudioEndpoint = "http://127.0.0.1:1234" DefaultOpenAIEndpoint = "https://api.openai.com/v1" + DefaultVoyageAIEndpoint = "https://api.voyageai.com/v1" DefaultSyntheticEndpoint = "https://api.synthetic.new/openai/v1" DefaultOpenRouterEndpoint = "https://openrouter.ai/api/v1" @@ -93,16 +95,17 @@ type BoostRule struct { } type EmbedderConfig struct { - Provider string `yaml:"provider"` // ollama | lmstudio | openai | synthetic | openrouter + Provider string `yaml:"provider"` // ollama | lmstudio | openai | voyageai | synthetic | openrouter Model string `yaml:"model"` Endpoint string `yaml:"endpoint,omitempty"` APIKey string `yaml:"api_key,omitempty"` Dimensions *int `yaml:"dimensions,omitempty"` - Parallelism int 
`yaml:"parallelism"` // Number of parallel workers for batch embedding (default: 4) + Parallelism int `yaml:"parallelism,omitempty"` // Number of parallel workers for batch embedding (default: 4) } // GetDimensions returns the configured dimensions or a default value. // For OpenAI/OpenRouter, defaults to 1536 (text-embedding-3-small). +// For Voyage AI, defaults to 1024 (voyage-code-3). // For Ollama/LMStudio/Synthetic, defaults to 768 (nomic-embed-text-v1.5). func (e *EmbedderConfig) GetDimensions() int { if e.Dimensions != nil { @@ -111,6 +114,8 @@ func (e *EmbedderConfig) GetDimensions() int { switch e.Provider { case "openai", "openrouter": return DefaultOpenAIDimensions + case "voyageai": + return 1024 default: return DefaultLocalEmbeddingDimensions } @@ -118,6 +123,13 @@ func (e *EmbedderConfig) GetDimensions() int { func DefaultEmbedderForProvider(provider string) EmbedderConfig { switch provider { + case "voyageai": + return EmbedderConfig{ + Provider: "voyageai", + Model: DefaultVoyageAIEmbeddingModel, + Endpoint: DefaultVoyageAIEndpoint, + Dimensions: nil, // Voyage AI uses native dimensions (1024) + } case "synthetic": dim := DefaultLocalEmbeddingDimensions return EmbedderConfig{ @@ -448,7 +460,7 @@ func (c *Config) applyDefaults() { } } - // Parallelism default (only used by OpenAI embedder) + // Parallelism default (used by OpenAI and Voyage AI embedders) if c.Embedder.Parallelism <= 0 { c.Embedder.Parallelism = 4 } diff --git a/config/config_test.go b/config/config_test.go index c06ab29..31cfa6a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -83,6 +83,14 @@ func TestDefaultEmbedderForProvider(t *testing.T) { if openai.Dimensions != nil { t.Fatalf("openai dimensions should be nil, got %v", openai.Dimensions) } + + voyageai := DefaultEmbedderForProvider("voyageai") + if voyageai.Endpoint != DefaultVoyageAIEndpoint || voyageai.Model != DefaultVoyageAIEmbeddingModel { + t.Fatalf("unexpected voyageai defaults: %+v", voyageai) + } + if 
voyageai.Dimensions != nil { + t.Fatalf("voyageai dimensions should be nil, got %v", voyageai.Dimensions) + } } func TestDefaultStoreForBackend(t *testing.T) { diff --git a/docs/src/content/docs/backends/embedders.md b/docs/src/content/docs/backends/embedders.md index 3f61a45..149be18 100644 --- a/docs/src/content/docs/backends/embedders.md +++ b/docs/src/content/docs/backends/embedders.md @@ -12,6 +12,7 @@ Embedders convert text (code chunks) into vector representations that enable sem | Ollama | Local | Privacy, free, no internet | Requires local resources | | LM Studio | Local | Privacy, OpenAI-compatible API, GUI | Requires local resources | | OpenAI | Cloud | High quality, fast | Costs money, sends code to cloud | +| Voyage AI | Cloud | Optimized for code, high quality | Costs money, sends code to cloud | ## Ollama (Local) @@ -233,6 +234,59 @@ For a typical codebase: - Initial index: ~$0.001 with `text-embedding-3-small` - Ongoing updates: negligible +## Voyage AI (Cloud) + +[Voyage AI](https://voyageai.com) provides embedding models specifically optimized for code search and retrieval. + +### Setup + +1. Get an API key from [Voyage AI Dashboard](https://dash.voyageai.com/api-keys) + +2. Set the environment variable: + +```bash +export VOYAGE_API_KEY=pa-... +``` + +### Configuration + +```yaml +embedder: + provider: voyageai + model: voyage-code-3 + endpoint: https://api.voyageai.com/v1 + api_key: ${VOYAGE_API_KEY} +``` + +### Available Models + +| Model | Dimensions | Context | Notes | +|-------|------------|---------|-------| +| `voyage-code-3` | 1024 | 32K | Optimized for code retrieval (recommended) | +| `voyage-4-large` | 1024 | 32K | Best general-purpose retrieval quality | +| `voyage-4` | 1024 | 32K | Balanced quality and performance | +| `voyage-4-lite` | 1024 | 32K | Optimized for latency and cost | + +All Voyage 4 series models support flexible dimensions (256, 512, 1024, 2048) and share a compatible embedding space. 
+ +### Parallelism & Rate Limiting + +Voyage AI embeddings support parallel batch processing with adaptive rate limiting: + +```yaml +embedder: + provider: voyageai + model: voyage-code-3 + api_key: ${VOYAGE_API_KEY} + parallelism: 4 # Concurrent API requests (default: 4) +``` + +Rate limiting works the same way as for OpenAI: on 429 responses, parallelism is automatically reduced and the failed requests are retried with exponential backoff. + +### Cost Estimation + +See [Voyage AI Pricing](https://docs.voyageai.com/docs/pricing) for current rates. + ## Changing Embedding Models You can use any embedding model available on your provider. Two parameters matter: diff --git a/docs/src/content/docs/commands/grepai_init.md b/docs/src/content/docs/commands/grepai_init.md index 2665caa..e2cbf18 100644 --- a/docs/src/content/docs/commands/grepai_init.md +++ b/docs/src/content/docs/commands/grepai_init.md @@ -13,7 +13,7 @@ Initialize grepai by creating a .grepai directory with configuration. This command will: - Create .grepai/config.yaml with default settings -- Prompt for embedding provider (Ollama or OpenAI) +- Prompt for embedding provider (Ollama, LM Studio, OpenAI, or Voyage AI) 
- Prompt for storage backend (GOB file or PostgreSQL) - Add .grepai/ to .gitignore if present @@ -27,7 +27,7 @@ grepai init [flags] -b, --backend string Storage backend (gob, postgres, or qdrant) -h, --help help for init --inherit Inherit configuration from main worktree (for git worktrees) - -p, --provider string Embedding provider (ollama, lmstudio, or openai) + -p, --provider string Embedding provider (ollama, lmstudio, openai, voyageai, synthetic, or openrouter) --ui Run interactive Bubble Tea UI wizard --yes Use defaults without prompting ``` diff --git a/docs/src/content/docs/configuration.md b/docs/src/content/docs/configuration.md index 1a6e502..e18c87d 100644 --- a/docs/src/content/docs/configuration.md +++ b/docs/src/content/docs/configuration.md @@ -17,7 +17,7 @@ version: 1 # Embedder configuration embedder: - # Provider: "ollama" (local), "lmstudio" (local), or "openai" (cloud) + # Provider: "ollama" (local), "lmstudio" (local), "openai" (cloud), or "voyageai" (cloud) provider: ollama # Model name (depends on provider) model: nomic-embed-text diff --git a/docs/src/content/docs/contributing.md b/docs/src/content/docs/contributing.md index 8525c9f..5b358e4 100644 --- a/docs/src/content/docs/contributing.md +++ b/docs/src/content/docs/contributing.md @@ -86,8 +86,10 @@ grepai/ ├── config/ # Configuration loading ├── embedder/ # Embedding providers │ ├── embedder.go # Interface +│ ├── factory.go # Provider factory │ ├── ollama.go -│ └── openai.go +│ ├── openai.go +│ └── voyageai.go ├── store/ # Vector storage │ ├── store.go # Interface │ ├── gob.go diff --git a/docs/src/content/docs/installation.md b/docs/src/content/docs/installation.md index 2bcf360..41133db 100644 --- a/docs/src/content/docs/installation.md +++ b/docs/src/content/docs/installation.md @@ -5,7 +5,7 @@ description: How to install grepai ## Prerequisites -- **Ollama** (for local embeddings) or an **OpenAI API key** (for cloud embeddings) +- **Ollama** (for local embeddings) or a cloud API 
key (**OpenAI**, **Voyage AI**, etc.) ## Homebrew (macOS) diff --git a/docs/src/content/docs/skills.md b/docs/src/content/docs/skills.md index f7549ef..d3f3f78 100644 --- a/docs/src/content/docs/skills.md +++ b/docs/src/content/docs/skills.md @@ -105,6 +105,7 @@ Install skills by category: | `grepai-embeddings-ollama` | Configure Ollama for local, private embeddings | | `grepai-embeddings-openai` | Configure OpenAI for cloud embeddings | | `grepai-embeddings-lmstudio` | Configure LM Studio with GUI interface | +| `grepai-embeddings-voyageai` | Configure Voyage AI for code-optimized cloud embeddings | ### Storage Backends | Skill | Description | diff --git a/docs/src/content/docs/workspace.md b/docs/src/content/docs/workspace.md index 067bcc0..21dac86 100644 --- a/docs/src/content/docs/workspace.md +++ b/docs/src/content/docs/workspace.md @@ -53,7 +53,7 @@ grepai workspace create my-fullstack --from workspace-config.yaml | Flag | Description | Default | |------|-------------|---------| | `--backend` | Storage backend (`qdrant` or `postgres`) | Required (or `--yes`) | -| `--provider` | Embedding provider (`ollama`, `openai`, `lmstudio`) | `ollama` with `--yes` | +| `--provider` | Embedding provider (`ollama`, `openai`, `lmstudio`, `voyageai`, `synthetic`, `openrouter`) | `ollama` with `--yes` | | `--model` | Embedding model name | Provider default | | `--endpoint` | Embedder endpoint URL | Provider default | | `--dsn` | PostgreSQL connection string | Required for postgres | diff --git a/embedder/batch.go b/embedder/batch.go index 93bd0d3..900178d 100644 --- a/embedder/batch.go +++ b/embedder/batch.go @@ -2,11 +2,28 @@ package embedder -// MaxBatchSize is the maximum number of inputs per OpenAI embedding API call. +// DefaultMaxBatchSize is the default maximum number of inputs per embedding API call. // OpenAI allows 2048, but we use 2000 as a safety margin. -const MaxBatchSize = 2000 +const DefaultMaxBatchSize = 2000 -// MaxBatchTokens is the maximum total tokens per OpenAI embedding API batch. +// DefaultMaxBatchTokens is the default maximum total tokens per embedding API batch. // OpenAI has a 300,000 token limit. 
We use 280,000 for safety margin. -const MaxBatchTokens = 280000 +const DefaultMaxBatchTokens = 280000 + +// BatchConfig holds configurable limits for batch formation. +// Providers can override these defaults based on their API constraints. +type BatchConfig struct { + // MaxBatchSize is the maximum number of inputs per API call. + MaxBatchSize int + // MaxBatchTokens is the maximum total tokens per batch. + MaxBatchTokens int +} + +// DefaultBatchConfig returns a BatchConfig with default values. +func DefaultBatchConfig() BatchConfig { + return BatchConfig{ + MaxBatchSize: DefaultMaxBatchSize, + MaxBatchTokens: DefaultMaxBatchTokens, + } +} // EstimateTokens estimates the token count for a text string. // Uses a conservative estimate of ~4 characters per token for English text. @@ -59,26 +76,30 @@ type FileChunks struct { // batchBuilder accumulates chunks into batches. type batchBuilder struct { - batches []Batch - current Batch - currentTokens int + batches []Batch + current Batch + currentTokens int + maxBatchSize int + maxBatchTokens int } -func newBatchBuilder(estimatedBatches int) *batchBuilder { +func newBatchBuilder(estimatedBatches int, cfg BatchConfig) *batchBuilder { return &batchBuilder{ batches: make([]Batch, 0, estimatedBatches), current: Batch{ Index: 0, - Entries: make([]BatchEntry, 0, MaxBatchSize), + Entries: make([]BatchEntry, 0, cfg.MaxBatchSize), }, + maxBatchSize: cfg.MaxBatchSize, + maxBatchTokens: cfg.MaxBatchTokens, } } func (b *batchBuilder) isFull(additionalTokens int) bool { - if len(b.current.Entries) >= MaxBatchSize { + if len(b.current.Entries) >= b.maxBatchSize { return true } - if len(b.current.Entries) > 0 && b.currentTokens+additionalTokens > MaxBatchTokens { + if len(b.current.Entries) > 0 && b.currentTokens+additionalTokens > b.maxBatchTokens { return true } return false @@ -88,7 +109,7 @@ func (b *batchBuilder) finalizeCurrent() { b.batches = append(b.batches, b.current) b.current = Batch{ Index: len(b.batches), - Entries: 
make([]BatchEntry, 0, MaxBatchSize), + Entries: make([]BatchEntry, 0, b.maxBatchSize), } b.currentTokens = 0 } @@ -110,16 +131,28 @@ func (b *batchBuilder) build() []Batch { } // FormBatches splits chunks from multiple files into batches respecting both -// MaxBatchSize (input count) and MaxBatchTokens (token limit). +// MaxBatchSize (input count) and MaxBatchTokens (token limit) from the given config. +// If no config is provided, default values are used. // Chunks maintain their file/chunk index tracking for result mapping. -func FormBatches(files []FileChunks) []Batch { +func FormBatches(files []FileChunks, configs ...BatchConfig) []Batch { + cfg := DefaultBatchConfig() + if len(configs) > 0 { + cfg = configs[0] + if cfg.MaxBatchSize <= 0 { + cfg.MaxBatchSize = DefaultMaxBatchSize + } + if cfg.MaxBatchTokens <= 0 { + cfg.MaxBatchTokens = DefaultMaxBatchTokens + } + } + totalChunks := countTotalChunks(files) if totalChunks == 0 { return nil } - estimatedBatches := (totalChunks + MaxBatchSize - 1) / MaxBatchSize - builder := newBatchBuilder(estimatedBatches) + estimatedBatches := (totalChunks + cfg.MaxBatchSize - 1) / cfg.MaxBatchSize + builder := newBatchBuilder(estimatedBatches, cfg) for _, file := range files { for chunkIdx, chunk := range file.Chunks { diff --git a/embedder/batch_test.go b/embedder/batch_test.go index 218bbba..9068191 100644 --- a/embedder/batch_test.go +++ b/embedder/batch_test.go @@ -96,8 +96,8 @@ func TestFormBatches_SingleFileFewChunks(t *testing.T) { } func TestFormBatches_SingleFileManyChunks(t *testing.T) { - // Create file with more than MaxBatchSize chunks - chunks := make([]string, MaxBatchSize+500) + // Create file with more than DefaultMaxBatchSize chunks + chunks := make([]string, DefaultMaxBatchSize+500) for i := range chunks { chunks[i] = "chunk" } @@ -113,8 +113,8 @@ func TestFormBatches_SingleFileManyChunks(t *testing.T) { } // First batch should be full - if len(batches[0].Entries) != MaxBatchSize { - t.Errorf("first batch 
should have %d entries, got %d", MaxBatchSize, len(batches[0].Entries)) + if len(batches[0].Entries) != DefaultMaxBatchSize { + t.Errorf("first batch should have %d entries, got %d", DefaultMaxBatchSize, len(batches[0].Entries)) } if batches[0].Index != 0 { t.Errorf("first batch.Index = %d, expected 0", batches[0].Index) @@ -135,7 +135,7 @@ func TestFormBatches_SingleFileManyChunks(t *testing.T) { } } for i, entry := range batches[1].Entries { - expectedIdx := MaxBatchSize + i + expectedIdx := DefaultMaxBatchSize + i if entry.ChunkIndex != expectedIdx { t.Errorf("batch[1].entry[%d].ChunkIndex = %d, expected %d", i, entry.ChunkIndex, expectedIdx) } @@ -190,7 +190,7 @@ func TestFormBatches_MultipleFilesCombined(t *testing.T) { func TestFormBatches_MultipleFilesBatchBoundary(t *testing.T) { // Create files that will span batch boundaries - file1Chunks := make([]string, MaxBatchSize-100) + file1Chunks := make([]string, DefaultMaxBatchSize-100) for i := range file1Chunks { file1Chunks[i] = "file1" } @@ -211,8 +211,8 @@ func TestFormBatches_MultipleFilesBatchBoundary(t *testing.T) { } // First batch: all of file1 (1900) + first 100 of file2 - if len(batches[0].Entries) != MaxBatchSize { - t.Errorf("first batch should have %d entries, got %d", MaxBatchSize, len(batches[0].Entries)) + if len(batches[0].Entries) != DefaultMaxBatchSize { + t.Errorf("first batch should have %d entries, got %d", DefaultMaxBatchSize, len(batches[0].Entries)) } // Second batch: remaining 100 of file2 @@ -246,8 +246,8 @@ func TestFormBatches_MultipleFilesBatchBoundary(t *testing.T) { } } -func TestFormBatches_ExactlyMaxBatchSize(t *testing.T) { - chunks := make([]string, MaxBatchSize) +func TestFormBatches_ExactlyDefaultMaxBatchSize(t *testing.T) { + chunks := make([]string, DefaultMaxBatchSize) for i := range chunks { chunks[i] = "chunk" } @@ -259,15 +259,15 @@ func TestFormBatches_ExactlyMaxBatchSize(t *testing.T) { batches := FormBatches(files) if len(batches) != 1 { - t.Errorf("expected 1 
batch for exactly %d chunks, got %d", MaxBatchSize, len(batches)) + t.Errorf("expected 1 batch for exactly %d chunks, got %d", DefaultMaxBatchSize, len(batches)) } - if len(batches[0].Entries) != MaxBatchSize { - t.Errorf("batch should have %d entries, got %d", MaxBatchSize, len(batches[0].Entries)) + if len(batches[0].Entries) != DefaultMaxBatchSize { + t.Errorf("batch should have %d entries, got %d", DefaultMaxBatchSize, len(batches[0].Entries)) } } -func TestFormBatches_ExactlyMaxBatchSizePlusOne(t *testing.T) { - chunks := make([]string, MaxBatchSize+1) +func TestFormBatches_ExactlyDefaultMaxBatchSizePlusOne(t *testing.T) { + chunks := make([]string, DefaultMaxBatchSize+1) for i := range chunks { chunks[i] = "chunk" } @@ -279,10 +279,10 @@ func TestFormBatches_ExactlyMaxBatchSizePlusOne(t *testing.T) { batches := FormBatches(files) if len(batches) != 2 { - t.Errorf("expected 2 batches for %d chunks, got %d", MaxBatchSize+1, len(batches)) + t.Errorf("expected 2 batches for %d chunks, got %d", DefaultMaxBatchSize+1, len(batches)) } - if len(batches[0].Entries) != MaxBatchSize { - t.Errorf("first batch should have %d entries, got %d", MaxBatchSize, len(batches[0].Entries)) + if len(batches[0].Entries) != DefaultMaxBatchSize { + t.Errorf("first batch should have %d entries, got %d", DefaultMaxBatchSize, len(batches[0].Entries)) } if len(batches[1].Entries) != 1 { t.Errorf("second batch should have 1 entry, got %d", len(batches[1].Entries)) @@ -409,7 +409,7 @@ func TestEstimateTokens(t *testing.T) { func TestFormBatches_TokenLimit(t *testing.T) { // Create chunks that are large enough to trigger token limit // Each chunk will be ~10000 chars -> ~2500 tokens - // With MaxBatchTokens = 280000, we can fit ~112 such chunks per batch + // With DefaultMaxBatchTokens = 280000, we can fit ~112 such chunks per batch largeChunk := string(make([]byte, 10000)) // Create 200 large chunks - should be split into multiple batches by token limit @@ -424,7 +424,7 @@ func 
TestFormBatches_TokenLimit(t *testing.T) { batches := FormBatches(files) - // Should have more than 1 batch due to token limit (even though count is below MaxBatchSize) + // Should have more than 1 batch due to token limit (even though count is below DefaultMaxBatchSize) if len(batches) < 2 { t.Errorf("expected multiple batches due to token limit, got %d", len(batches)) } @@ -444,17 +444,17 @@ func TestFormBatches_TokenLimit(t *testing.T) { for _, entry := range batch.Entries { batchTokens += EstimateTokens(entry.Content) } - if batchTokens > MaxBatchTokens { - t.Errorf("batch %d has %d tokens, exceeds MaxBatchTokens %d", i, batchTokens, MaxBatchTokens) + if batchTokens > DefaultMaxBatchTokens { + t.Errorf("batch %d has %d tokens, exceeds DefaultMaxBatchTokens %d", i, batchTokens, DefaultMaxBatchTokens) } } } func TestFormBatches_SmallChunksIgnoreTokenLimit(t *testing.T) { - // With small chunks, we should hit the count limit (MaxBatchSize) before token limit + // With small chunks, we should hit the count limit (DefaultMaxBatchSize) before token limit smallChunk := "hello" - chunks := make([]string, MaxBatchSize+100) + chunks := make([]string, DefaultMaxBatchSize+100) for i := range chunks { chunks[i] = smallChunk } @@ -470,8 +470,8 @@ func TestFormBatches_SmallChunksIgnoreTokenLimit(t *testing.T) { t.Errorf("expected 2 batches (split by count), got %d", len(batches)) } - // First batch should be exactly MaxBatchSize - if len(batches[0].Entries) != MaxBatchSize { - t.Errorf("first batch should have %d entries, got %d", MaxBatchSize, len(batches[0].Entries)) + // First batch should be exactly DefaultMaxBatchSize + if len(batches[0].Entries) != DefaultMaxBatchSize { + t.Errorf("first batch should have %d entries, got %d", DefaultMaxBatchSize, len(batches[0].Entries)) } } diff --git a/embedder/embedder.go b/embedder/embedder.go index d72f762..f94be90 100644 --- a/embedder/embedder.go +++ b/embedder/embedder.go @@ -33,4 +33,8 @@ type BatchEmbedder interface { // It 
returns results mapped back to their source files, or an error if any batch fails. // The progress callback is called for each batch completion or retry attempt. EmbedBatches(ctx context.Context, batches []Batch, progress BatchProgress) ([]BatchResult, error) + + // BatchConfig returns the provider-specific batch configuration. + // Each provider may have different API limits for batch size and token count. + BatchConfig() BatchConfig } diff --git a/embedder/embedder_test.go b/embedder/embedder_test.go index 2a54705..434a0b6 100644 --- a/embedder/embedder_test.go +++ b/embedder/embedder_test.go @@ -1,6 +1,8 @@ package embedder import ( + "encoding/json" + "strings" "testing" ) @@ -294,6 +296,210 @@ func TestEmbedder_Close(t *testing.T) { t.Errorf("Close() returned error: %v", err) } }) + + t.Run("VoyageAIEmbedder", func(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + e, err := NewVoyageAIEmbedder() + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + if err := e.Close(); err != nil { + t.Errorf("Close() returned error: %v", err) + } + }) +} + +// Test VoyageAIEmbedder options +func TestNewVoyageAIEmbedder_Defaults(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + + e, err := NewVoyageAIEmbedder() + if err != nil { + t.Fatalf("failed to create VoyageAIEmbedder: %v", err) + } + + if e.endpoint != defaultVoyageAIEndpoint { + t.Errorf("expected endpoint %s, got %s", defaultVoyageAIEndpoint, e.endpoint) + } + + if e.model != defaultVoyageAIModel { + t.Errorf("expected model %s, got %s", defaultVoyageAIModel, e.model) + } + + // dimensions should be nil by default (no output_dimension param sent to API) + if e.dimensions != nil { + t.Errorf("expected nil dimensions, got %v", e.dimensions) + } +} + +func TestNewVoyageAIEmbedder_WithOptions(t *testing.T) { + customEndpoint := "https://custom-voyage.example.com/v1" + customModel := "voyage-3" + customKey := "va-custom-key" + customDimensions := 512 + + e, err := NewVoyageAIEmbedder( + 
WithVoyageAIEndpoint(customEndpoint), + WithVoyageAIModel(customModel), + WithVoyageAIKey(customKey), + WithVoyageAIDimensions(customDimensions), + WithVoyageAIInputType("document"), + ) + if err != nil { + t.Fatalf("failed to create VoyageAIEmbedder: %v", err) + } + + if e.endpoint != customEndpoint { + t.Errorf("expected endpoint %s, got %s", customEndpoint, e.endpoint) + } + + if e.model != customModel { + t.Errorf("expected model %s, got %s", customModel, e.model) + } + + if e.apiKey != customKey { + t.Errorf("expected apiKey %s, got %s", customKey, e.apiKey) + } + + if e.dimensions == nil || *e.dimensions != customDimensions { + t.Errorf("expected dimensions %d, got %v", customDimensions, e.dimensions) + } + + if e.inputType != "document" { + t.Errorf("expected inputType 'document', got %s", e.inputType) + } +} + +func TestNewVoyageAIEmbedder_RequiresAPIKey(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "") + + _, err := NewVoyageAIEmbedder() + if err == nil { + t.Fatal("expected error when API key is not set") + } +} + +func TestNewVoyageAIEmbedder_UsesEnvAPIKey(t *testing.T) { + envKey := "va-env-test-key" + t.Setenv("VOYAGE_API_KEY", envKey) + + e, err := NewVoyageAIEmbedder() + if err != nil { + t.Fatalf("failed to create VoyageAIEmbedder: %v", err) + } + + if e.apiKey != envKey { + t.Errorf("expected apiKey from env %s, got %s", envKey, e.apiKey) + } +} + +func TestNewVoyageAIEmbedder_ExplicitKeyOverridesEnv(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "env-key") + explicitKey := "va-explicit-key" + + e, err := NewVoyageAIEmbedder(WithVoyageAIKey(explicitKey)) + if err != nil { + t.Fatalf("failed to create VoyageAIEmbedder: %v", err) + } + + if e.apiKey != explicitKey { + t.Errorf("expected explicit apiKey %s, got %s", explicitKey, e.apiKey) + } +} + +func TestVoyageAIEmbedder_Dimensions(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + + tests := []struct { + name string + dimensions int + }{ + {"default", defaultVoyageAIDimensions}, + {"custom 
256", 256}, + {"custom 512", 512}, + {"custom 2048", 2048}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var e *VoyageAIEmbedder + var err error + if tt.dimensions == defaultVoyageAIDimensions { + e, err = NewVoyageAIEmbedder() + } else { + e, err = NewVoyageAIEmbedder(WithVoyageAIDimensions(tt.dimensions)) + } + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + if e.Dimensions() != tt.dimensions { + t.Errorf("expected Dimensions() to return %d, got %d", tt.dimensions, e.Dimensions()) + } + }) + } +} + +func TestVoyageAIEmbedder_RequestUsesOutputDimension(t *testing.T) { + // Verify the JSON request body uses "output_dimension" not "dimensions" + dimensions := 512 + req := voyageAIEmbedRequest{ + Model: "voyage-code-3", + Input: []string{"test"}, + OutputDimension: &dimensions, + InputType: "document", + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + + jsonStr := string(data) + + // Must contain "output_dimension", not "dimensions" + if !strings.Contains(jsonStr, `"output_dimension"`) { + t.Errorf("expected JSON to contain 'output_dimension', got: %s", jsonStr) + } + + // Must NOT contain bare "dimensions" key (only "output_dimension") + // Parse as generic map to check exact keys + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if _, ok := parsed["dimensions"]; ok { + t.Errorf("JSON should not contain 'dimensions' key, got: %s", jsonStr) + } + + if val, ok := parsed["output_dimension"]; !ok { + t.Errorf("JSON should contain 'output_dimension' key, got: %s", jsonStr) + } else if int(val.(float64)) != dimensions { + t.Errorf("expected output_dimension=%d, got %v", dimensions, val) + } +} + +func TestVoyageAIEmbedder_RequestOmitsNilDimension(t *testing.T) { + // When dimensions is nil, output_dimension should be omitted (omitempty) + req := 
voyageAIEmbedRequest{ + Model: "voyage-code-3", + Input: []string{"test"}, + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if _, ok := parsed["output_dimension"]; ok { + t.Errorf("output_dimension should be omitted when nil, got: %s", string(data)) + } } // Test endpoint option combinations @@ -365,3 +571,28 @@ func TestOpenAIEmbedder_EndpointVariants(t *testing.T) { }) } } + +func TestVoyageAIEmbedder_EndpointVariants(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + + tests := []struct { + name string + endpoint string + }{ + {"default", "https://api.voyageai.com/v1"}, + {"custom", "https://custom-voyage-proxy.example.com/v1"}, + {"local proxy", "http://localhost:8080/v1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e, err := NewVoyageAIEmbedder(WithVoyageAIEndpoint(tt.endpoint)) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + if e.endpoint != tt.endpoint { + t.Errorf("expected endpoint %s, got %s", tt.endpoint, e.endpoint) + } + }) + } +} diff --git a/embedder/factory.go b/embedder/factory.go index a200a43..da2b1c7 100644 --- a/embedder/factory.go +++ b/embedder/factory.go @@ -54,6 +54,18 @@ func NewFromConfig(cfg *config.Config) (Embedder, error) { } return NewSyntheticEmbedder(opts...) + case "voyageai": + opts := []VoyageAIOption{ + WithVoyageAIModel(cfg.Embedder.Model), + WithVoyageAIKey(cfg.Embedder.APIKey), + WithVoyageAIEndpoint(cfg.Embedder.Endpoint), + WithVoyageAIParallelism(cfg.Embedder.Parallelism), + } + if cfg.Embedder.Dimensions != nil { + opts = append(opts, WithVoyageAIDimensions(*cfg.Embedder.Dimensions)) + } + return NewVoyageAIEmbedder(opts...) 
+ case "openrouter": opts := []OpenRouterOption{ WithOpenRouterModel(cfg.Embedder.Model), diff --git a/embedder/factory_test.go b/embedder/factory_test.go index a0e18bc..346cf4a 100644 --- a/embedder/factory_test.go +++ b/embedder/factory_test.go @@ -122,6 +122,30 @@ func TestNewFromConfig_OpenRouter(t *testing.T) { } } +func TestNewFromConfig_VoyageAI(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + + cfg := &config.Config{ + Embedder: config.EmbedderConfig{ + Provider: "voyageai", + Model: "voyage-code-3", + Endpoint: "https://api.voyageai.com/v1", + Parallelism: 4, + }, + } + + emb, err := NewFromConfig(cfg) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + defer emb.Close() + + _, ok := emb.(*VoyageAIEmbedder) + if !ok { + t.Errorf("expected *VoyageAIEmbedder, got %T", emb) + } +} + func TestNewFromConfig_UnknownProvider(t *testing.T) { cfg := &config.Config{ Embedder: config.EmbedderConfig{ diff --git a/embedder/openai.go b/embedder/openai.go index 79ac0e0..5639b4d 100644 --- a/embedder/openai.go +++ b/embedder/openai.go @@ -225,6 +225,15 @@ func (e *OpenAIEmbedder) Close() error { return nil } +// BatchConfig returns batch limits tuned for the OpenAI embeddings API. +// OpenAI allows up to 2048 inputs and ~300k tokens per batch. +func (e *OpenAIEmbedder) BatchConfig() BatchConfig { + return BatchConfig{ + MaxBatchSize: 2000, + MaxBatchTokens: 280000, + } +} + // EmbedBatches implements the BatchEmbedder interface. // It processes multiple batches concurrently using a bounded worker pool // and retries failed requests with exponential backoff. 
diff --git a/embedder/voyageai.go b/embedder/voyageai.go
new file mode 100644
index 0000000..e640d19
--- /dev/null
+++ b/embedder/voyageai.go
@@ -0,0 +1,447 @@
+package embedder
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/sync/errgroup"
+)
+
+// Defaults applied by NewVoyageAIEmbedder when no option overrides them.
+const (
+	defaultVoyageAIEndpoint   = "https://api.voyageai.com/v1"
+	defaultVoyageAIModel      = "voyage-code-3"
+	defaultVoyageAIDimensions = 1024
+)
+
+// VoyageAIEmbedder produces embeddings by POSTing to the "/embeddings" path
+// of a Voyage AI compatible endpoint. Construct it with NewVoyageAIEmbedder.
+type VoyageAIEmbedder struct {
+	endpoint    string
+	model       string
+	apiKey      string
+	dimensions  *int   // nil = omit output_dimension from the request
+	inputType   string // optional: "query", "document", or "" (none)
+	parallelism int
+	retryPolicy RetryPolicy
+	client      *http.Client
+	rateLimiter *AdaptiveRateLimiter
+	tokenBucket *TokenBucket
+	tpmLimit    int64 // Tokens per minute limit (0 = disabled)
+}
+
+// voyageAIEmbedRequest is the JSON body sent to the embeddings endpoint.
+// Note the field is "output_dimension", not OpenAI's "dimensions".
+type voyageAIEmbedRequest struct {
+	Model           string   `json:"model"`
+	Input           []string `json:"input"`
+	OutputDimension *int     `json:"output_dimension,omitempty"`
+	InputType       string   `json:"input_type,omitempty"`
+}
+
+// voyageAIErrorResponse shares the same structure as OpenAI.
+type voyageAIErrorResponse = openAIErrorResponse
+
+// VoyageAIOption customizes a VoyageAIEmbedder during construction.
+type VoyageAIOption func(*VoyageAIEmbedder)
+
+// WithVoyageAIEndpoint overrides the API base URL. An empty value is ignored
+// so the default endpoint survives an unset configuration field, mirroring
+// how WithVoyageAIParallelism ignores non-positive values.
+func WithVoyageAIEndpoint(endpoint string) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		if endpoint != "" {
+			e.endpoint = endpoint
+		}
+	}
+}
+
+// WithVoyageAIModel overrides the embedding model. An empty value is ignored
+// so the default model survives an unset configuration field.
+func WithVoyageAIModel(model string) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		if model != "" {
+			e.model = model
+		}
+	}
+}
+
+// WithVoyageAIKey sets the API key. When the key is empty,
+// NewVoyageAIEmbedder falls back to the VOYAGE_API_KEY environment variable.
+func WithVoyageAIKey(key string) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		e.apiKey = key
+	}
+}
+
+// WithVoyageAIDimensions requests a specific embedding size; the value is
+// sent as output_dimension and reported by Dimensions.
+func WithVoyageAIDimensions(dimensions int) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		e.dimensions = &dimensions
+	}
+}
+
+// WithVoyageAIInputType sets the input_type parameter for the Voyage AI API.
+// Options: "query", "document", or "" (unspecified).
+func WithVoyageAIInputType(inputType string) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		e.inputType = inputType
+	}
+}
+
+// WithVoyageAIParallelism sets the worker count passed to the adaptive rate
+// limiter. Values <= 0 are ignored so the default stays in effect.
+func WithVoyageAIParallelism(parallelism int) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		if parallelism > 0 {
+			e.parallelism = parallelism
+		}
+	}
+}
+
+// WithVoyageAIRetryPolicy replaces the default retry policy consulted when a
+// batch request fails with a retryable error.
+func WithVoyageAIRetryPolicy(policy RetryPolicy) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		e.retryPolicy = policy
+	}
+}
+
+// WithVoyageAITPMLimit sets the tokens-per-minute limit for proactive rate limiting.
+// When set > 0, the embedder will pace requests to stay within this limit.
+func WithVoyageAITPMLimit(tpm int64) VoyageAIOption {
+	return func(e *VoyageAIEmbedder) {
+		if tpm > 0 {
+			e.tpmLimit = tpm
+		}
+	}
+}
+
+// NewVoyageAIEmbedder builds an embedder from the defaults plus opts.
+//
+// The API key comes from WithVoyageAIKey or, if that is empty, from the
+// VOYAGE_API_KEY environment variable; an error is returned when neither
+// provides one. The token bucket is only created when a TPM limit was set.
+func NewVoyageAIEmbedder(opts ...VoyageAIOption) (*VoyageAIEmbedder, error) {
+	e := &VoyageAIEmbedder{
+		endpoint:    defaultVoyageAIEndpoint,
+		model:       defaultVoyageAIModel,
+		dimensions:  nil, // nil = let the model use its native dimensions
+		parallelism: defaultParallelism,
+		retryPolicy: DefaultRetryPolicy(),
+		client: &http.Client{
+			Timeout: 60 * time.Second,
+		},
+	}
+
+	for _, opt := range opts {
+		opt(e)
+	}
+
+	// Try to get API key from environment if not set
+	if e.apiKey == "" {
+		e.apiKey = os.Getenv("VOYAGE_API_KEY")
+	}
+
+	if e.apiKey == "" {
+		return nil, fmt.Errorf("Voyage AI API key not set (use VOYAGE_API_KEY environment variable)")
+	}
+
+	// Initialize adaptive rate limiter with configured parallelism
+	e.rateLimiter = NewAdaptiveRateLimiter(e.parallelism)
+
+	// Initialize token bucket if TPM limit is set
+	if e.tpmLimit > 0 {
+		e.tokenBucket = NewTokenBucket(e.tpmLimit)
+	}
+
+	return e, nil
+}
+
+// Embed returns the embedding for a single text by sending a one-element
+// batch.
+func (e *VoyageAIEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
+	embeddings, err := e.EmbedBatch(ctx, []string{text})
+	if err != nil {
+		return nil, err
+	}
+	// Guard the index so an empty result surfaces as an error, not a panic.
+	if len(embeddings) == 0 {
+		return nil, fmt.Errorf("no embedding returned by Voyage AI")
+	}
+	return embeddings[0], nil
+}
+
+// EmbedBatch embeds texts with a single request and no retries. An empty
+// input yields (nil, nil) without contacting the API.
+//
+// It delegates to embedBatchRequest, which performs exactly the steps this
+// method used to duplicate (build request, send, read body, map non-200
+// responses to errors, parse embeddings).
+func (e *VoyageAIEmbedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
+	return e.embedBatchRequest(ctx, texts)
+}
+
+// Dimensions reports the configured embedding size, falling back to
+// defaultVoyageAIDimensions when none was set.
+// NOTE(review): the fallback looks tied to the default model; models with a
+// different native size would need WithVoyageAIDimensions — confirm.
+func (e *VoyageAIEmbedder) Dimensions() int {
+	if e.dimensions == nil {
+		return defaultVoyageAIDimensions
+	}
+	return *e.dimensions
+}
+
+// Close is a no-op and always returns nil.
+func (e *VoyageAIEmbedder) Close() error {
+	return nil
+}
+
+// BatchConfig returns batch limits tuned for the Voyage AI embeddings API.
+func (e *VoyageAIEmbedder) BatchConfig() BatchConfig {
+	return BatchConfig{
+		MaxBatchSize:   500,
+		MaxBatchTokens: 50000,
+	}
+}
+
+// EmbedBatches implements the BatchEmbedder interface.
+// It processes multiple batches concurrently using a bounded worker pool
+// and retries failed requests with exponential backoff.
+func (e *VoyageAIEmbedder) EmbedBatches(ctx context.Context, batches []Batch, progress BatchProgress) ([]BatchResult, error) {
+	batchCount := len(batches)
+	if batchCount == 0 {
+		return nil, nil
+	}
+
+	// Sum the chunk counts up front so progress callbacks can report a total.
+	var chunkTotal int
+	for idx := range batches {
+		chunkTotal += batches[idx].Size()
+	}
+
+	// Shared by all workers; updated atomically as batches finish.
+	var done atomic.Int64
+	out := make([]BatchResult, batchCount)
+
+	// Concurrency is capped at whatever the adaptive limiter allows right now.
+	group, groupCtx := errgroup.WithContext(ctx)
+	group.SetLimit(e.rateLimiter.CurrentWorkers())
+
+	for _, b := range batches {
+		b := b // private copy for the closure below
+		group.Go(func() error {
+			vectors, err := e.embedBatchWithRetry(groupCtx, b, batchCount, chunkTotal, &done, progress)
+			if err != nil {
+				return err
+			}
+			// Results are slotted by the batch's own index, not loop position.
+			out[b.Index] = BatchResult{BatchIndex: b.Index, Embeddings: vectors}
+			return nil
+		})
+	}
+
+	if err := group.Wait(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// waitForTokenBucket blocks for the delay the token bucket requests before
+// tokens may be spent. It returns immediately when no bucket is configured
+// or no delay is needed, and returns ctx.Err() if ctx ends during the wait.
+func (e *VoyageAIEmbedder) waitForTokenBucket(ctx context.Context, tokens int64) error {
+	bucket := e.tokenBucket
+	if bucket == nil {
+		return nil
+	}
+
+	delay := bucket.WaitForTokens(tokens)
+	if delay <= 0 {
+		return nil
+	}
+
+	timer := time.NewTimer(delay)
+	defer timer.Stop()
+
+	select {
+	case <-timer.C:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// calculateRetryDelay determines the delay before the next retry attempt.
+// Uses Retry-After header if available, otherwise falls back to exponential backoff.
+func (e *VoyageAIEmbedder) calculateRetryDelay(attempt int, retryErr *RetryableError) time.Duration {
+	// Upper bound on how long a server-supplied Retry-After is honored.
+	const maxServerDelay = 60 * time.Second
+
+	headers := retryErr.RateLimitHeaders
+	if headers == nil || headers.RetryAfter <= 0 {
+		// No usable server hint: use the configured retry policy.
+		return e.retryPolicy.Calculate(attempt)
+	}
+	if headers.RetryAfter > maxServerDelay {
+		return maxServerDelay
+	}
+	return headers.RetryAfter
+}
+
+// reportBatchSuccess records a completed batch. In order, it notifies the
+// rate limiter, adds the estimated tokens to the token bucket (if any),
+// advances the shared chunk counter, and invokes the progress callback
+// (if any) with the retry flag cleared.
+func (e *VoyageAIEmbedder) reportBatchSuccess(
+	batch Batch,
+	totalBatches int,
+	totalChunks int,
+	completedChunks *atomic.Int64,
+	estimatedTokens int64,
+	progress BatchProgress,
+) {
+	e.rateLimiter.OnSuccess()
+
+	if bucket := e.tokenBucket; bucket != nil {
+		bucket.AddTokens(estimatedTokens)
+	}
+
+	// The counter is advanced even when nobody listens for progress.
+	finished := completedChunks.Add(int64(batch.Size()))
+	if progress == nil {
+		return
+	}
+	progress(batch.Index, totalBatches, int(finished), totalChunks, false, 0, 0)
+}
+
+// estimateBatchTokens returns the estimated token count for a batch.
+func (e *VoyageAIEmbedder) estimateBatchTokens(contents []string) int64 {
+	// The estimate only feeds the token bucket; skip the work without one.
+	if e.tokenBucket == nil {
+		return 0
+	}
+
+	var sum int64
+	for i := range contents {
+		sum += int64(EstimateTokens(contents[i]))
+	}
+	return sum
+}
+
+// embedBatchWithRetry sends one batch, retrying until it succeeds, hits a
+// non-retryable error, exhausts the retry policy, or ctx ends.
+//
+// Each attempt first waits on the token bucket. A 429 response is reported
+// to the adaptive rate limiter. Before sleeping for a retry, the progress
+// callback (if any) is invoked with the retry flag set, the 1-based attempt
+// number and the HTTP status code.
+func (e *VoyageAIEmbedder) embedBatchWithRetry(
+	ctx context.Context,
+	batch Batch,
+	totalBatches int,
+	totalChunks int,
+	completedChunks *atomic.Int64,
+	progress BatchProgress,
+) ([][]float32, error) {
+	texts := batch.Contents()
+	tokenCost := e.estimateBatchTokens(texts)
+
+	attempt := 0
+	for {
+		if err := e.waitForTokenBucket(ctx, tokenCost); err != nil {
+			return nil, err
+		}
+
+		vectors, err := e.embedBatchRequest(ctx, texts)
+		if err == nil {
+			e.reportBatchSuccess(batch, totalBatches, totalChunks, completedChunks, tokenCost, progress)
+			return vectors, nil
+		}
+
+		// Only errors explicitly marked retryable go around the loop again.
+		rerr, ok := err.(*RetryableError)
+		if !(ok && rerr.Retryable) {
+			return nil, err
+		}
+
+		// Report the rate limit even if this turns out to be the last attempt.
+		if rerr.StatusCode == http.StatusTooManyRequests {
+			e.rateLimiter.OnRateLimitHit()
+		}
+
+		if !e.retryPolicy.ShouldRetry(attempt) {
+			return nil, fmt.Errorf("batch %d failed after %d attempts: %w", batch.Index, attempt+1, err)
+		}
+
+		if progress != nil {
+			progress(batch.Index, totalBatches, int(completedChunks.Load()), totalChunks, true, attempt+1, rerr.StatusCode)
+		}
+
+		pause := e.calculateRetryDelay(attempt, rerr)
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-time.After(pause):
+		}
+
+		attempt++
+	}
+}
+
+// buildEmbedHTTPRequest creates an HTTP request for the Voyage AI embeddings API.
+func (e *VoyageAIEmbedder) buildEmbedHTTPRequest(ctx context.Context, texts []string) (*http.Request, error) {
+	payload, err := json.Marshal(voyageAIEmbedRequest{
+		Model:           e.model,
+		Input:           texts,
+		OutputDimension: e.dimensions,
+		InputType:       e.inputType,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.endpoint+"/embeddings", bytes.NewReader(payload))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)
+	return httpReq, nil
+}
+
+// handleVoyageAIErrorResponse converts a non-200 response into an error.
+//
+// The message is the "error.message" field when the body parses as the
+// shared error JSON, otherwise the raw body. A 400 whose message contains
+// one of the known token-limit phrases becomes a context-length error.
+// Everything else becomes a RetryableError, and a 429 additionally carries
+// the parsed rate-limit headers.
+func handleVoyageAIErrorResponse(resp *http.Response, body []byte) error {
+	status := resp.StatusCode
+
+	detail := string(body)
+	var parsed voyageAIErrorResponse
+	if err := json.Unmarshal(body, &parsed); err == nil && parsed.Error.Message != "" {
+		detail = parsed.Error.Message
+	}
+
+	if status == http.StatusBadRequest {
+		tokenLimitPhrases := []string{
+			"maximum context length",
+			"too many tokens",
+			"reduce the length",
+			"total number of tokens",
+		}
+		for _, phrase := range tokenLimitPhrases {
+			if strings.Contains(detail, phrase) {
+				// NOTE(review): 32000 is presumably the model's token limit — confirm.
+				return NewContextLengthError(0, 0, 32000, detail)
+			}
+		}
+	}
+
+	apiErr := NewRetryableError(status, fmt.Sprintf("Voyage AI API error (status %d): %s", status, detail))
+	if status == http.StatusTooManyRequests {
+		limits := parseRateLimitHeaders(resp.Header)
+		apiErr.RateLimitHeaders = &limits
+	}
+	return apiErr
+}
+
+// embedBatchRequest makes a single embedding request to the Voyage AI API.
+// It returns a RetryableError for HTTP errors that can be retried.
+func (e *VoyageAIEmbedder) embedBatchRequest(ctx context.Context, texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + + req, err := e.buildEmbedHTTPRequest(ctx, texts) + if err != nil { + return nil, err + } + + resp, err := e.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request to Voyage AI: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, handleVoyageAIErrorResponse(resp, body) + } + + return parseEmbeddingsResponse(body, len(texts)) +} diff --git a/embedder/voyageai_batch_test.go b/embedder/voyageai_batch_test.go new file mode 100644 index 0000000..523eb6e --- /dev/null +++ b/embedder/voyageai_batch_test.go @@ -0,0 +1,760 @@ +package embedder + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestVoyageAIEmbedder_EmbedBatches_ParallelismLimit(t *testing.T) { + var ( + maxConcurrent int32 + current int32 + mu sync.Mutex + requestCount int32 + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := atomic.AddInt32(&current, 1) + defer atomic.AddInt32(&current, -1) + + mu.Lock() + if c > maxConcurrent { + maxConcurrent = c + } + mu.Unlock() + + atomic.AddInt32(&requestCount, 1) + + time.Sleep(50 * time.Millisecond) + + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + parallelism := 2 + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(parallelism), +
WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := make([]Batch, 4) + for i := range batches { + batches[i] = Batch{ + Index: i, + Entries: []BatchEntry{ + {FileIndex: i, ChunkIndex: 0, Content: "test content"}, + }, + } + } + + ctx := context.Background() + results, err := e.EmbedBatches(ctx, batches, nil) + if err != nil { + t.Fatalf("EmbedBatches failed: %v", err) + } + + if len(results) != len(batches) { + t.Errorf("expected %d results, got %d", len(batches), len(results)) + } + + if maxConcurrent > int32(parallelism) { + t.Errorf("max concurrent %d exceeded parallelism limit %d", maxConcurrent, parallelism) + } + + if atomic.LoadInt32(&requestCount) != int32(len(batches)) { + t.Errorf("expected %d requests, got %d", len(batches), requestCount) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_RequestFormat(t *testing.T) { + // Verify that the actual HTTP request uses "output_dimension" not "dimensions" + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := make([]byte, r.ContentLength) + r.Body.Read(body) + capturedBody = body + + resp := mockEmbeddingResponse(1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + dims := 512 + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIDimensions(dims), + WithVoyageAIInputType("document"), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + ctx := context.Background() + _, err = e.EmbedBatches(ctx, batches, nil) + if err != nil { + t.Fatalf("EmbedBatches failed: %v", err) + } + + // Parse the captured request body + var parsed map[string]interface{} + if err := json.Unmarshal(capturedBody, &parsed); err 
!= nil { + t.Fatalf("failed to parse request body: %v", err) + } + + // Must have "output_dimension", not "dimensions" + if _, ok := parsed["dimensions"]; ok { + t.Errorf("request body should NOT contain 'dimensions', got: %s", string(capturedBody)) + } + + if val, ok := parsed["output_dimension"]; !ok { + t.Errorf("request body should contain 'output_dimension', got: %s", string(capturedBody)) + } else if int(val.(float64)) != dims { + t.Errorf("expected output_dimension=%d, got %v", dims, val) + } + + // Check input_type is set + if val, ok := parsed["input_type"]; !ok { + t.Errorf("request body should contain 'input_type', got: %s", string(capturedBody)) + } else if val != "document" { + t.Errorf("expected input_type='document', got %v", val) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_ResultMapping(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{ + {FileIndex: 0, ChunkIndex: 0, Content: "file0 chunk0"}, + {FileIndex: 0, ChunkIndex: 1, Content: "file0 chunk1"}, + {FileIndex: 1, ChunkIndex: 0, Content: "file1 chunk0"}, + }, + }, + { + Index: 1, + Entries: []BatchEntry{ + {FileIndex: 1, ChunkIndex: 1, Content: "file1 chunk1"}, + {FileIndex: 2, ChunkIndex: 0, Content: "file2 chunk0"}, + }, + }, + } + + ctx := context.Background() + results, err := e.EmbedBatches(ctx, batches, nil) + if err != nil { + 
t.Fatalf("EmbedBatches failed: %v", err) + } + + if len(results) != len(batches) { + t.Errorf("expected %d results, got %d", len(batches), len(results)) + } + + for _, result := range results { + expectedCount := len(batches[result.BatchIndex].Entries) + if len(result.Embeddings) != expectedCount { + t.Errorf("batch %d: expected %d embeddings, got %d", + result.BatchIndex, expectedCount, len(result.Embeddings)) + } + } + + fileEmbeddings := MapResultsToFiles(batches, results, 3) + if len(fileEmbeddings) != 3 { + t.Errorf("expected 3 file embeddings, got %d", len(fileEmbeddings)) + } + + if len(fileEmbeddings[0]) != 2 { + t.Errorf("file 0: expected 2 chunks, got %d", len(fileEmbeddings[0])) + } + + if len(fileEmbeddings[1]) != 2 { + t.Errorf("file 1: expected 2 chunks, got %d", len(fileEmbeddings[1])) + } + + if len(fileEmbeddings[2]) != 1 { + t.Errorf("file 2: expected 1 chunk, got %d", len(fileEmbeddings[2])) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_RetryOn429(t *testing.T) { + var requestCount int32 + rateLimitUntil := int32(2) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&requestCount, 1) + + if count <= rateLimitUntil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]string{ + "message": "Rate limit exceeded", + "type": "rate_limit_error", + }, + }) + return + } + + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + fastRetryPolicy := RetryPolicy{ + BaseDelay: 10 * time.Millisecond, + Multiplier: 2.0, + MaxDelay: 100 * time.Millisecond, + MaxAttempts: 5, + } + + e, err := 
NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(1), + WithVoyageAIRetryPolicy(fastRetryPolicy), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + var retryCount int32 + progress := func(batchIndex, totalBatches, completedChunks, totalChunks int, retrying bool, attempt int, statusCode int) { + if retrying { + atomic.AddInt32(&retryCount, 1) + } + } + + ctx := context.Background() + results, err := e.EmbedBatches(ctx, batches, progress) + if err != nil { + t.Fatalf("EmbedBatches failed: %v", err) + } + + if len(results) != 1 { + t.Errorf("expected 1 result, got %d", len(results)) + } + + if atomic.LoadInt32(&retryCount) != 2 { + t.Errorf("expected 2 retries, got %d", retryCount) + } + + if atomic.LoadInt32(&requestCount) != 3 { + t.Errorf("expected 3 requests, got %d", requestCount) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_FailOn4xx(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]string{ + "message": "Invalid API key", + "type": "invalid_request_error", + }, + }) + })) + defer server.Close() + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("invalid-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + ctx := context.Background() + _, err = e.EmbedBatches(ctx, batches, nil) + if err == nil { + t.Fatal("expected error for 401 response") + } + + retryErr, ok := 
err.(*RetryableError) + if !ok { + t.Fatalf("expected RetryableError, got %T", err) + } + if retryErr.Retryable { + t.Error("401 error should not be retryable") + } +} + +func TestVoyageAIEmbedder_EmbedBatches_ContextCancellation(t *testing.T) { + requestStarted := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(requestStarted) + time.Sleep(5 * time.Second) + + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error, 1) + go func() { + _, err := e.EmbedBatches(ctx, batches, nil) + errChan <- err + }() + + <-requestStarted + cancel() + + select { + case err := <-errChan: + if err == nil { + t.Error("expected error after context cancellation") + } + case <-time.After(2 * time.Second): + t.Error("timeout waiting for cancellation") + } +} + +func TestVoyageAIEmbedder_EmbedBatches_EmptyInput(t *testing.T) { + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + ctx := context.Background() + results, err := e.EmbedBatches(ctx, nil, nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if results != nil { + t.Errorf("expected nil results for empty input, got %v", results) + } + + results, err = 
e.EmbedBatches(ctx, []Batch{}, nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if results != nil { + t.Errorf("expected nil results for empty input, got %v", results) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_ProgressCallback(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(1), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := make([]Batch, 3) + for i := range batches { + batches[i] = Batch{ + Index: i, + Entries: []BatchEntry{{FileIndex: i, ChunkIndex: 0, Content: "test"}}, + } + } + + type progressInfo struct { + batchIndex int + totalBatches int + completedChunks int + totalChunks int + retrying bool + attempt int + } + var progressCalls []progressInfo + var mu sync.Mutex + progress := func(batchIndex, totalBatches, completedChunks, totalChunks int, retrying bool, attempt int, statusCode int) { + mu.Lock() + progressCalls = append(progressCalls, progressInfo{ + batchIndex: batchIndex, + totalBatches: totalBatches, + completedChunks: completedChunks, + totalChunks: totalChunks, + retrying: retrying, + attempt: attempt, + }) + mu.Unlock() + } + + ctx := context.Background() + _, err = e.EmbedBatches(ctx, batches, progress) + if err != nil { + t.Fatalf("EmbedBatches failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if len(progressCalls) != 3 { + t.Errorf("expected 3 progress calls, got %d", len(progressCalls)) + } + + for _, call := range progressCalls { + 
if call.totalBatches != 3 { + t.Errorf("expected totalBatches=3, got %d", call.totalBatches) + } + if call.retrying { + t.Error("unexpected retry flag") + } + } +} + +func TestVoyageAIEmbedder_WithParallelism(t *testing.T) { + t.Setenv("VOYAGE_API_KEY", "test-key") + + tests := []struct { + name string + parallelism int + expected int + }{ + {"default", 0, defaultParallelism}, + {"explicit 1", 1, 1}, + {"explicit 2", 2, 2}, + {"explicit 8", 8, 8}, + {"negative ignored", -1, defaultParallelism}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var e *VoyageAIEmbedder + var err error + if tt.parallelism == 0 { + e, err = NewVoyageAIEmbedder() + } else { + e, err = NewVoyageAIEmbedder(WithVoyageAIParallelism(tt.parallelism)) + } + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + if e.parallelism != tt.expected { + t.Errorf("expected parallelism %d, got %d", tt.expected, e.parallelism) + } + }) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_RetryOn5xx(t *testing.T) { + var requestCount int32 + serverErrorUntil := int32(2) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&requestCount, 1) + + if count <= serverErrorUntil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]string{ + "message": "Service temporarily unavailable", + "type": "server_error", + }, + }) + return + } + + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + fastRetryPolicy := RetryPolicy{ + BaseDelay: 10 * time.Millisecond, + Multiplier: 2.0, + MaxDelay: 100 * time.Millisecond, + 
MaxAttempts: 5, + } + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(1), + WithVoyageAIRetryPolicy(fastRetryPolicy), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + var retryCount int32 + progress := func(batchIndex, totalBatches, completedChunks, totalChunks int, retrying bool, attempt int, statusCode int) { + if retrying { + atomic.AddInt32(&retryCount, 1) + } + } + + ctx := context.Background() + results, err := e.EmbedBatches(ctx, batches, progress) + if err != nil { + t.Fatalf("EmbedBatches failed: %v", err) + } + + if len(results) != 1 { + t.Errorf("expected 1 result, got %d", len(results)) + } + + if atomic.LoadInt32(&retryCount) != 2 { + t.Errorf("expected 2 retries, got %d", retryCount) + } + + if atomic.LoadInt32(&requestCount) != 3 { + t.Errorf("expected 3 requests, got %d", requestCount) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_MaxRetryLimit(t *testing.T) { + var requestCount int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]string{ + "message": "Rate limit exceeded", + "type": "rate_limit_error", + }, + }) + })) + defer server.Close() + + fastRetryPolicy := RetryPolicy{ + BaseDelay: 5 * time.Millisecond, + Multiplier: 2.0, + MaxDelay: 50 * time.Millisecond, + MaxAttempts: 3, + } + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(1), + WithVoyageAIRetryPolicy(fastRetryPolicy), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create 
embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, Content: "test"}}, + }, + } + + ctx := context.Background() + _, err = e.EmbedBatches(ctx, batches, nil) + if err == nil { + t.Fatal("expected error after max retries") + } + + expectedRequests := int32(fastRetryPolicy.MaxAttempts + 1) + if atomic.LoadInt32(&requestCount) != expectedRequests { + t.Errorf("expected %d requests (1 initial + %d retries), got %d", + expectedRequests, fastRetryPolicy.MaxAttempts, requestCount) + } + + if !strings.Contains(err.Error(), "batch 0 failed") { + t.Errorf("expected error to mention batch failure, got: %v", err) + } +} + +func TestVoyageAIEmbedder_EmbedBatches_ParallelBatchFailure(t *testing.T) { + var requestCount int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&requestCount, 1) + + var req voyageAIEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + for _, input := range req.Input { + if strings.Contains(input, "batch1") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]string{ + "message": "Invalid API key", + "type": "invalid_request_error", + }, + }) + return + } + } + + if count == 1 { + time.Sleep(20 * time.Millisecond) + } + + resp := mockEmbeddingResponse(len(req.Input)) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + e, err := NewVoyageAIEmbedder( + WithVoyageAIKey("test-key"), + WithVoyageAIEndpoint(server.URL), + WithVoyageAIParallelism(2), + WithVoyageAIDimensions(3), + ) + if err != nil { + t.Fatalf("failed to create embedder: %v", err) + } + + batches := []Batch{ + { + Index: 0, + Entries: []BatchEntry{{FileIndex: 0, ChunkIndex: 0, 
Content: "batch0 content"}}, + }, + { + Index: 1, + Entries: []BatchEntry{{FileIndex: 1, ChunkIndex: 0, Content: "batch1 content"}}, + }, + } + + ctx := context.Background() + _, err = e.EmbedBatches(ctx, batches, nil) + if err == nil { + t.Fatal("expected error when batch fails") + } + + retryErr, ok := err.(*RetryableError) + if !ok { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status 401, got %d", retryErr.StatusCode) + } +} diff --git a/indexer/indexer.go b/indexer/indexer.go index 257afd4..1492958 100644 --- a/indexer/indexer.go +++ b/indexer/indexer.go @@ -401,7 +401,7 @@ func (idx *Indexer) indexFilesBatched( // Embed remaining (non-cached) files if len(remainingFileChunks) > 0 { - batches := embedder.FormBatches(remainingFileChunks) + batches := embedder.FormBatches(remainingFileChunks, batchEmb.BatchConfig()) results, err := batchEmb.EmbedBatches(ctx, batches, wrapBatchProgress(onProgress)) if err != nil { return filesIndexed, chunksCreated, fmt.Errorf("failed to embed batches: %w", err) diff --git a/indexer/indexer_test.go b/indexer/indexer_test.go index 5b7535f..e679816 100644 --- a/indexer/indexer_test.go +++ b/indexer/indexer_test.go @@ -482,6 +482,10 @@ func (m *mockBatchEmbedder) Close() error { return nil } +func (m *mockBatchEmbedder) BatchConfig() embedder.BatchConfig { + return embedder.DefaultBatchConfig() +} + func (m *mockBatchEmbedder) EmbedBatches(ctx context.Context, batches []embedder.Batch, progress embedder.BatchProgress) ([]embedder.BatchResult, error) { m.embedCalled = true