From 2ea7f46becb069a4d19e01fa59d6bc2f53cdb067 Mon Sep 17 00:00:00 2001 From: Liu Yuan Date: Tue, 17 Mar 2026 23:56:34 +0800 Subject: [PATCH 1/2] feat(config): support multiple API keys for failover Add api_keys field to ModelConfig to support multiple API keys with automatic failover. When multiple keys are configured, they are expanded into separate model entries with fallbacks set up for key-level failover. Example config: { "model_name": "glm-4.7", "model": "zhipu/glm-4.7", "api_keys": ["key1", "key2", "key3"] } Expands internally to: - glm-4.7 (key1) -> fallbacks: [glm-4.7__key_1, glm-4.7__key_2] - glm-4.7__key_1 (key2) - glm-4.7__key_2 (key3) Backward compatible: single api_key still works as before. --- pkg/config/config.go | 105 ++++++++++++- pkg/config/multikey_test.go | 291 ++++++++++++++++++++++++++++++++++++ 2 files changed, 393 insertions(+), 3 deletions(-) create mode 100644 pkg/config/multikey_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index 39154372b1..3ac63572b1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -597,9 +597,11 @@ type ModelConfig struct { Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6") // HTTP-based providers - APIBase string `json:"api_base,omitempty"` // API endpoint URL - APIKey string `json:"api_key"` // API authentication key - Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + APIBase string `json:"api_base,omitempty"` // API endpoint URL + APIKey string `json:"api_key"` // API authentication key (single key) + APIKeys []string `json:"api_keys,omitempty"` // API authentication keys (multiple keys for failover) + Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + Fallbacks []string `json:"fallbacks,omitempty"` // Fallback model names for failover // Special providers (CLI-based, OAuth, etc.) AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token @@ -861,6 +863,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Expand multi-key configs into separate entries for key-level failover + cfg.ModelList = ExpandMultiKeyModels(cfg.ModelList) + // Migrate legacy channel config fields to new unified structures cfg.migrateChannelConfigs() @@ -907,14 +912,25 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo // resolveAPIKeys decrypts or dereferences each api_key in models in-place. // Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt). +// Also resolves api_keys array if present. func resolveAPIKeys(models []ModelConfig, configDir string) error { cr := credential.NewResolver(configDir) for i := range models { + // Resolve single APIKey resolved, err := cr.Resolve(models[i].APIKey) if err != nil { return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err) } models[i].APIKey = resolved + + // Resolve APIKeys array + for j, key := range models[i].APIKeys { + resolved, err := cr.Resolve(key) + if err != nil { + return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err) + } + models[i].APIKeys[j] = resolved + } } return nil } @@ -1085,6 +1101,89 @@ func MergeAPIKeys(apiKey string, apiKeys []string) []string { return all } +// ExpandMultiKeyModels expands ModelConfig entries with multiple API keys into +// separate entries for key-level failover. Each key gets its own ModelConfig entry, +// and the original entry's fallbacks are set up to chain through the expanded entries. +// +// Example: {"model_name": "gpt-4", "api_keys": ["k1", "k2", "k3"]} +// Becomes: +// - {"model_name": "gpt-4", "api_key": "k1", "fallbacks": ["gpt-4__key_1", "gpt-4__key_2"]} +// - {"model_name": "gpt-4__key_1", "api_key": "k2"} +// - {"model_name": "gpt-4__key_2", "api_key": "k3"} +func ExpandMultiKeyModels(models []ModelConfig) []ModelConfig { + var expanded []ModelConfig + + for _, m := range models { + keys := MergeAPIKeys(m.APIKey, m.APIKeys) + + // Single key or no keys: keep as-is + if len(keys) <= 1 { + // Ensure APIKey is set from APIKeys if needed + if m.APIKey == "" && len(keys) == 1 { + m.APIKey = keys[0] + } + m.APIKeys = nil // Clear APIKeys to avoid confusion + expanded = append(expanded, m) + continue + } + + // Multiple keys: expand + originalName := m.ModelName + + // Create entries for additional keys (key_1, key_2, ...) + var fallbackNames []string + for i := 1; i < len(keys); i++ { + suffix := fmt.Sprintf("__key_%d", i) + expandedName := originalName + suffix + + // Create a copy for the additional key + additionalEntry := ModelConfig{ + ModelName: expandedName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[i], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + expanded = append(expanded, additionalEntry) + fallbackNames = append(fallbackNames, expandedName) + } + + // Create the primary entry with first key and fallbacks + primaryEntry := ModelConfig{ + ModelName: originalName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[0], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + + // Prepend new fallbacks to existing ones + if len(fallbackNames) > 0 { + primaryEntry.Fallbacks = append(fallbackNames, m.Fallbacks...) + } else if len(m.Fallbacks) > 0 { + primaryEntry.Fallbacks = m.Fallbacks + } + + expanded = append(expanded, primaryEntry) + } + + return expanded +} + func (t *ToolsConfig) IsToolEnabled(name string) bool { switch name { case "web": diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go new file mode 100644 index 0000000000..b899b991cd --- /dev/null +++ b/pkg/config/multikey_test.go @@ -0,0 +1,291 @@ +package config + +import ( + "testing" +) + +func TestExpandMultiKeyModels_SingleKey(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "single-key", + }, + } + + result := ExpandMultiKeyModels(models) + + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } + + if result[0].APIKey != "single-key" { + t.Errorf("expected api_key 'single-key', got %q", result[0].APIKey) + } + + if len(result[0].Fallbacks) != 0 { + t.Errorf("expected no fallbacks, got %v", result[0].Fallbacks) + } +} + +func TestExpandMultiKeyModels_APIKeysOnly(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "glm-4.7", + Model: "zhipu/glm-4.7", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2", "key3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // First entry should be the primary with key1 and fallbacks + primary := result[2] // Primary is added last + if primary.ModelName != "glm-4.7" { + t.Errorf("expected primary model_name 'glm-4.7', got %q", primary.ModelName) + } + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } + if primary.Fallbacks[0] != "glm-4.7__key_1" { + t.Errorf("expected first fallback 'glm-4.7__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "glm-4.7__key_2" { + t.Errorf("expected second fallback 'glm-4.7__key_2', got %q", primary.Fallbacks[1]) + } + + // Second entry should be key2 + second := result[0] + if second.ModelName != "glm-4.7__key_1" { + t.Errorf("expected second model_name 'glm-4.7__key_1', got %q", second.ModelName) + } + if second.APIKey != "key2" { + t.Errorf("expected second api_key 'key2', got %q", second.APIKey) + } + + // Third entry should be key3 + third := result[1] + if third.ModelName != "glm-4.7__key_2" { + t.Errorf("expected third model_name 'glm-4.7__key_2', got %q", third.ModelName) + } + if third.APIKey != "key3" { + t.Errorf("expected third api_key 'key3', got %q", third.APIKey) + } +} + +func TestExpandMultiKeyModels_APIKeyAndAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key0", + APIKeys: []string{"key1", "key2"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models (key0 from APIKey + key1, key2 from APIKeys) + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // Primary should use key0 + primary := result[2] + if primary.APIKey != "key0" { + t.Errorf("expected primary api_key 'key0', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKeys: []string{"key1", "key2"}, + Fallbacks: []string{"claude-3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + primary := result[1] + // With 2 keys, we get 1 key fallback + 1 existing fallback = 2 total + if len(primary.Fallbacks) != 2 { + t.Fatalf("expected 2 fallbacks, got %d: %v", len(primary.Fallbacks), primary.Fallbacks) + } + + // Key fallbacks should come first, then existing fallbacks + if primary.Fallbacks[0] != "gpt-4__key_1" { + t.Errorf("expected first fallback 'gpt-4__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "claude-3" { + t.Errorf("expected second fallback 'claude-3', got %q", primary.Fallbacks[1]) + } +} + +func TestExpandMultiKeyModels_EmptyAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "", + APIKeys: []string{}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should keep as-is with no changes + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } +} + +func TestExpandMultiKeyModels_Deduplication(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key1", + APIKeys: []string{"key1", "key2", "key1"}, // Duplicate key1 + }, + } + + result := ExpandMultiKeyModels(models) + + // Should only create 2 models (deduplicated keys) + if len(result) != 2 { + t.Fatalf("expected 2 models (deduplicated), got %d", len(result)) + } + + primary := result[1] + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 1 { + t.Errorf("expected 1 fallback, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2"}, + Proxy: "http://proxy:8080", + RPM: 60, + MaxTokensField: "max_completion_tokens", + RequestTimeout: 30, + ThinkingLevel: "high", + }, + } + + result := ExpandMultiKeyModels(models) + + // Check primary entry preserves all fields + primary := result[1] + if primary.APIBase != "https://api.example.com" { + t.Errorf("expected api_base preserved, got %q", primary.APIBase) + } + if primary.Proxy != "http://proxy:8080" { + t.Errorf("expected proxy preserved, got %q", primary.Proxy) + } + if primary.RPM != 60 { + t.Errorf("expected rpm preserved, got %d", primary.RPM) + } + if primary.MaxTokensField != "max_completion_tokens" { + t.Errorf("expected max_tokens_field preserved, got %q", primary.MaxTokensField) + } + if primary.RequestTimeout != 30 { + t.Errorf("expected request_timeout preserved, got %d", primary.RequestTimeout) + } + if primary.ThinkingLevel != "high" { + t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel) + } + + // Check additional entry also preserves fields + additional := result[0] + if additional.APIBase != "https://api.example.com" { + t.Errorf("expected additional api_base preserved, got %q", additional.APIBase) + } + if additional.RPM != 60 { + t.Errorf("expected additional rpm preserved, got %d", additional.RPM) + } +} + +func TestMergeAPIKeys(t *testing.T) { + tests := []struct { + name string + apiKey string + apiKeys []string + expected []string + }{ + { + name: "both empty", + apiKey: "", + apiKeys: nil, + expected: nil, + }, + { + name: "only apiKey", + apiKey: "key1", + apiKeys: nil, + expected: []string{"key1"}, + }, + { + name: "only apiKeys", + apiKey: "", + apiKeys: []string{"key1", "key2"}, + expected: []string{"key1", "key2"}, + }, + { + name: "both with overlap", + apiKey: "key1", + apiKeys: []string{"key1", "key2", "key3"}, + expected: []string{"key1", "key2", "key3"}, + }, + { + name: "with whitespace", + apiKey: " key1 ", + apiKeys: []string{" key2 ", " key1 "}, + expected: []string{"key1", "key2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MergeAPIKeys(tt.apiKey, tt.apiKeys) + if len(result) != len(tt.expected) { + t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result)) + } + for i, k := range result { + if k != tt.expected[i] { + t.Errorf("expected key[%d] = %q, got %q", i, tt.expected[i], k) + } + } + }) + } +} From 38e144d2d0747997c839cc8a1e94dbfc8d8f4fbc Mon Sep 17 00:00:00 2001 From: Liu Yuan Date: Wed, 18 Mar 2026 01:53:24 +0800 Subject: [PATCH 2/2] fix(providers): change cooldown tracking from provider to ModelKey This enables proper key-switching when multiple API keys share the same provider. Previously, when one key failed, all keys were blocked because cooldown was tracked per-provider. Now each (provider, model) combination has independent cooldown, allowing fallback to alternate keys when one is rate limited. Includes TestMultiKeyWithModelFallback and related failover tests. --- pkg/providers/fallback.go | 16 +- pkg/providers/fallback_multikey_test.go | 384 ++++++++++++++++++++++++ pkg/providers/fallback_test.go | 15 +- 3 files changed, 401 insertions(+), 14 deletions(-) create mode 100644 pkg/providers/fallback_multikey_test.go diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index 7ba563b66f..549ec78378 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -117,17 +117,19 @@ func (fc *FallbackChain) Execute( return nil, context.Canceled } - // Check cooldown. - if !fc.cooldown.IsAvailable(candidate.Provider) { - remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + // Check cooldown (per provider/model, not just provider). + // This allows multi-key failover where different keys use different model names. + cooldownKey := ModelKey(candidate.Provider, candidate.Model) + if !fc.cooldown.IsAvailable(cooldownKey) { + remaining := fc.cooldown.CooldownRemaining(cooldownKey) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, Skipped: true, Reason: FailoverRateLimit, Error: fmt.Errorf( - "provider %s in cooldown (%s remaining)", - candidate.Provider, + "%s in cooldown (%s remaining)", + cooldownKey, remaining.Round(time.Second), ), }) @@ -141,7 +143,7 @@ func (fc *FallbackChain) Execute( if err == nil { // Success. - fc.cooldown.MarkSuccess(candidate.Provider) + fc.cooldown.MarkSuccess(cooldownKey) result.Response = resp result.Provider = candidate.Provider result.Model = candidate.Model @@ -187,7 +189,7 @@ func (fc *FallbackChain) Execute( } // Retriable error: mark failure and continue to next candidate. - fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + fc.cooldown.MarkFailure(cooldownKey, failErr.Reason) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, diff --git a/pkg/providers/fallback_multikey_test.go b/pkg/providers/fallback_multikey_test.go new file mode 100644 index 0000000000..9ed8fa73cb --- /dev/null +++ b/pkg/providers/fallback_multikey_test.go @@ -0,0 +1,384 @@ +package providers + +import ( + "context" + "errors" + "testing" +) + +// TestMultiKeyFailover tests the complete failover flow with multiple API keys. +// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"] +// is expanded into primary + fallbacks. +func TestMultiKeyFailover(t *testing.T) { + // Simulate expanded config: primary with 2 fallbacks + // This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"] + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Create fallback chain + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with 429, second succeeds + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + if callCount == 1 { + // First call: simulate rate limit + return nil, errors.New("http error: status 429 - rate limit exceeded") + } + // Second call: success + return &LLMResponse{ + Content: "Hello from key2!", + }, nil + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover, got error: %v", err) + } + + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.Response.Content != "Hello from key2!" { + t.Errorf("expected response from key2, got: %s", result.Response.Content) + } + + if callCount != 2 { + t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount) + } + + // Verify first attempt was recorded + if len(result.Attempts) != 1 { + t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts)) + } + + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf( + "expected first attempt reason to be rate_limit, got: %s", + result.Attempts[0].Reason, + ) + } +} + +// TestMultiKeyFailoverAllFail tests when all keys hit rate limit +func TestMultiKeyFailoverAllFail(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: all calls fail with rate limit + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("status: 429 - too many requests") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error when all keys fail, got nil") + } + + if result != nil { + t.Errorf("expected nil result on failure, got: %v", result) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (all fail), got %d", callCount) + } + + // Verify error type + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err) + } + + if len(exhausted.Attempts) != 3 { + t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts)) + } +} + +// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped +func TestMultiKeyFailoverCooldown(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Put the first model in cooldown (using ModelKey now, not just provider) + cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model) + cooldown.MarkFailure(cooldownKey, FailoverRateLimit) + + // Verify it's not available + if cooldown.IsAvailable(cooldownKey) { + t.Fatal("expected first model to be in cooldown") + } + + // Mock run function: only second should be called + callCount := 0 + calledProviders := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledProviders = append(calledProviders, provider+"/"+model) + return &LLMResponse{Content: "success"}, nil + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + + // First provider should have been skipped + if callCount != 1 { + t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount) + } + + // Should have called the second provider/model + if len(calledProviders) != 1 || + calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model { + t.Errorf("expected second model to be called, got: %v", calledProviders) + } + + // Verify first attempt was recorded as skipped + if len(result.Attempts) != 1 { + t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts)) + } + + if !result.Attempts[0].Skipped { + t.Error("expected first attempt to be marked as skipped") + } +} + +// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable +func TestMultiKeyFailoverWithFormatError(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with format error (bad request) + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("invalid request format: tool_use.id missing") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error for format failure, got nil") + } + + // Format errors should NOT trigger failover (non-retriable) + // So we should only have 1 call + if callCount != 1 { + t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount) + } + + // Verify the error is a FailoverError with format reason + var failoverErr *FailoverError + if !errors.As(err, &failoverErr) { + t.Errorf("expected FailoverError, got: %T - %v", err, err) + } + + if failoverErr.Reason != FailoverFormat { + t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason) + } + + _ = result // result should be nil +} + +// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback. +// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"] +// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax +func TestMultiKeyWithModelFallback(t *testing.T) { + // Simulate expanded config from: + // { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] } + // After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"] + // Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax" + // In this test, we use the full format to avoid needing a lookup function. + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + // Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax) + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Verify candidate order + if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" { + t.Errorf( + "expected first candidate to be zhipu/glm-4.7, got: %s/%s", + candidates[0].Provider, + candidates[0].Model, + ) + } + if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" { + t.Errorf( + "expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s", + candidates[1].Provider, + candidates[1].Model, + ) + } + if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" { + t.Errorf( + "expected third candidate to be minimax/minimax, got: %s/%s", + candidates[2].Provider, + candidates[2].Model, + ) + } + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first two fail, third succeeds (model fallback) + callCount := 0 + calledModels := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledModels = append(calledModels, provider+"/"+model) + + switch callCount { + case 1: + // k1: rate limit + return nil, errors.New("status: 429 - rate limit") + case 2: + // k2: also rate limit (all zhipu keys exhausted) + return nil, errors.New("status: 429 - rate limit") + case 3: + // minimax: success + return &LLMResponse{Content: "success from minimax"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover to model fallback, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount) + } + + if result.Response.Content != "success from minimax" { + t.Errorf("expected response from minimax, got: %s", result.Response.Content) + } + + // Verify call order + if len(calledModels) != 3 { + t.Fatalf("expected 3 called models, got %d", len(calledModels)) + } + if calledModels[0] != "zhipu/glm-4.7" { + t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0]) + } + if calledModels[1] != "zhipu/glm-4.7__key_1" { + t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1]) + } + if calledModels[2] != "minimax/minimax" { + t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2]) + } + + // Verify 2 failed attempts recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // Both should be rate limit + for i, attempt := range result.Attempts { + if attempt.Reason != FailoverRateLimit { + t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason) + } + } +} + +// TestMultiKeyFailoverMixedErrors tests failover with different error types +func TestMultiKeyFailoverMixedErrors(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: different errors for each key + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + switch callCount { + case 1: + // First: rate limit (retriable) + return nil, errors.New("status: 429 - rate limit") + case 2: + // Second: timeout (retriable) + return nil, errors.New("context deadline exceeded") + case 3: + // Third: success + return &LLMResponse{Content: "success from key3"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after 2 failovers, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls, got %d", callCount) + } + + // Verify both failed attempts were recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // First should be rate limit + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason) + } + + // Second should be timeout + if result.Attempts[1].Reason != FailoverTimeout { + t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason) + } +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 1783ebcb5f..1a1118e336 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -157,8 +157,8 @@ func TestFallback_CooldownSkip(t *testing.T) { ct, _ := newTestTracker(now) fc := NewFallbackChain(ct) - // Put openai in cooldown - ct.MarkFailure("openai", FailoverRateLimit) + // Put openai/gpt-4 in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -195,9 +195,9 @@ func TestFallback_AllInCooldown(t *testing.T) { ct := NewCooldownTracker() fc := NewFallbackChain(ct) - // Put all providers in cooldown - ct.MarkFailure("openai", FailoverRateLimit) - ct.MarkFailure("anthropic", FailoverBilling) + // Put all models in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) + ct.MarkFailure(ModelKey("anthropic", "claude"), FailoverBilling) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -273,12 +273,13 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { fc := NewFallbackChain(ct) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + modelKey := ModelKey("openai", "gpt-4") attempt := 0 run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { attempt++ if attempt == 1 { - ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + ct.MarkFailure(modelKey, FailoverRateLimit) // simulate failure tracked elsewhere } return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil } @@ -287,7 +288,7 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !ct.IsAvailable("openai") { + if !ct.IsAvailable(modelKey) { t.Error("success should reset cooldown") } }