feat(config): support multiple API keys for failover (#1707)

* feat(config): support multiple API keys for failover

Add api_keys field to ModelConfig to support multiple API keys with
automatic failover. When multiple keys are configured, they are expanded
into separate model entries with fallbacks set up for key-level failover.

Example config:
  {
    "model_name": "glm-4.7",
    "model": "zhipu/glm-4.7",
    "api_keys": ["key1", "key2", "key3"]
  }

Expands internally to:
  - glm-4.7 (key1) -> fallbacks: [glm-4.7__key_1, glm-4.7__key_2]
  - glm-4.7__key_1 (key2)
  - glm-4.7__key_2 (key3)

Backward compatible: single api_key still works as before.

* fix(providers): change cooldown tracking from provider to ModelKey

This enables proper key-switching when multiple API keys share the same
provider. Previously, when one key failed, all keys were blocked because
cooldown was tracked per-provider.

Now each (provider, model) combination has independent cooldown, allowing
fallback to alternate keys when one is rate limited.

Includes TestMultiKeyWithModelFallback and related failover tests.
This commit is contained in:
Liu Yuan
2026-03-19 00:57:20 +08:00
committed by GitHub
parent 08f305d712
commit e73d9d959e
5 changed files with 794 additions and 17 deletions
+102 -3
View File
@@ -603,9 +603,11 @@ type ModelConfig struct {
Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6")
// HTTP-based providers
APIBase string `json:"api_base,omitempty"` // API endpoint URL
APIKey string `json:"api_key"` // API authentication key
Proxy string `json:"proxy,omitempty"` // HTTP proxy URL
APIBase string `json:"api_base,omitempty"` // API endpoint URL
APIKey string `json:"api_key"` // API authentication key (single key)
APIKeys []string `json:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
Proxy string `json:"proxy,omitempty"` // HTTP proxy URL
Fallbacks []string `json:"fallbacks,omitempty"` // Fallback model names for failover
// Special providers (CLI-based, OAuth, etc.)
AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token
@@ -874,6 +876,9 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
// Expand multi-key configs into separate entries for key-level failover
cfg.ModelList = ExpandMultiKeyModels(cfg.ModelList)
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -920,14 +925,25 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
// Also resolves api_keys array if present.
func resolveAPIKeys(models []ModelConfig, configDir string) error {
cr := credential.NewResolver(configDir)
for i := range models {
// Resolve single APIKey
resolved, err := cr.Resolve(models[i].APIKey)
if err != nil {
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
}
models[i].APIKey = resolved
// Resolve APIKeys array
for j, key := range models[i].APIKeys {
resolved, err := cr.Resolve(key)
if err != nil {
return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err)
}
models[i].APIKeys[j] = resolved
}
}
return nil
}
@@ -1098,6 +1114,89 @@ func MergeAPIKeys(apiKey string, apiKeys []string) []string {
return all
}
// ExpandMultiKeyModels expands ModelConfig entries with multiple API keys into
// separate entries for key-level failover. Each key gets its own ModelConfig entry,
// and the original entry's fallbacks are set up to chain through the expanded entries.
//
// Example: {"model_name": "gpt-4", "api_keys": ["k1", "k2", "k3"]}
// Becomes:
// - {"model_name": "gpt-4", "api_key": "k1", "fallbacks": ["gpt-4__key_1", "gpt-4__key_2"]}
// - {"model_name": "gpt-4__key_1", "api_key": "k2"}
// - {"model_name": "gpt-4__key_2", "api_key": "k3"}
func ExpandMultiKeyModels(models []ModelConfig) []ModelConfig {
var expanded []ModelConfig
for _, m := range models {
keys := MergeAPIKeys(m.APIKey, m.APIKeys)
// Single key or no keys: keep as-is
if len(keys) <= 1 {
// Ensure APIKey is set from APIKeys if needed
if m.APIKey == "" && len(keys) == 1 {
m.APIKey = keys[0]
}
m.APIKeys = nil // Clear APIKeys to avoid confusion
expanded = append(expanded, m)
continue
}
// Multiple keys: expand
originalName := m.ModelName
// Create entries for additional keys (key_1, key_2, ...)
var fallbackNames []string
for i := 1; i < len(keys); i++ {
suffix := fmt.Sprintf("__key_%d", i)
expandedName := originalName + suffix
// Create a copy for the additional key
additionalEntry := ModelConfig{
ModelName: expandedName,
Model: m.Model,
APIBase: m.APIBase,
APIKey: keys[i],
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
}
expanded = append(expanded, additionalEntry)
fallbackNames = append(fallbackNames, expandedName)
}
// Create the primary entry with first key and fallbacks
primaryEntry := ModelConfig{
ModelName: originalName,
Model: m.Model,
APIBase: m.APIBase,
APIKey: keys[0],
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
}
// Prepend new fallbacks to existing ones
if len(fallbackNames) > 0 {
primaryEntry.Fallbacks = append(fallbackNames, m.Fallbacks...)
} else if len(m.Fallbacks) > 0 {
primaryEntry.Fallbacks = m.Fallbacks
}
expanded = append(expanded, primaryEntry)
}
return expanded
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
+291
View File
@@ -0,0 +1,291 @@
package config
import (
"testing"
)
func TestExpandMultiKeyModels_SingleKey(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIKey: "single-key",
},
}
result := ExpandMultiKeyModels(models)
if len(result) != 1 {
t.Fatalf("expected 1 model, got %d", len(result))
}
if result[0].ModelName != "gpt-4" {
t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName)
}
if result[0].APIKey != "single-key" {
t.Errorf("expected api_key 'single-key', got %q", result[0].APIKey)
}
if len(result[0].Fallbacks) != 0 {
t.Errorf("expected no fallbacks, got %v", result[0].Fallbacks)
}
}
func TestExpandMultiKeyModels_APIKeysOnly(t *testing.T) {
models := []ModelConfig{
{
ModelName: "glm-4.7",
Model: "zhipu/glm-4.7",
APIBase: "https://api.example.com",
APIKeys: []string{"key1", "key2", "key3"},
},
}
result := ExpandMultiKeyModels(models)
// Should expand to 3 models
if len(result) != 3 {
t.Fatalf("expected 3 models, got %d", len(result))
}
// First entry should be the primary with key1 and fallbacks
primary := result[2] // Primary is added last
if primary.ModelName != "glm-4.7" {
t.Errorf("expected primary model_name 'glm-4.7', got %q", primary.ModelName)
}
if primary.APIKey != "key1" {
t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey)
}
if len(primary.Fallbacks) != 2 {
t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks))
}
if primary.Fallbacks[0] != "glm-4.7__key_1" {
t.Errorf("expected first fallback 'glm-4.7__key_1', got %q", primary.Fallbacks[0])
}
if primary.Fallbacks[1] != "glm-4.7__key_2" {
t.Errorf("expected second fallback 'glm-4.7__key_2', got %q", primary.Fallbacks[1])
}
// Second entry should be key2
second := result[0]
if second.ModelName != "glm-4.7__key_1" {
t.Errorf("expected second model_name 'glm-4.7__key_1', got %q", second.ModelName)
}
if second.APIKey != "key2" {
t.Errorf("expected second api_key 'key2', got %q", second.APIKey)
}
// Third entry should be key3
third := result[1]
if third.ModelName != "glm-4.7__key_2" {
t.Errorf("expected third model_name 'glm-4.7__key_2', got %q", third.ModelName)
}
if third.APIKey != "key3" {
t.Errorf("expected third api_key 'key3', got %q", third.APIKey)
}
}
func TestExpandMultiKeyModels_APIKeyAndAPIKeys(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIKey: "key0",
APIKeys: []string{"key1", "key2"},
},
}
result := ExpandMultiKeyModels(models)
// Should expand to 3 models (key0 from APIKey + key1, key2 from APIKeys)
if len(result) != 3 {
t.Fatalf("expected 3 models, got %d", len(result))
}
// Primary should use key0
primary := result[2]
if primary.APIKey != "key0" {
t.Errorf("expected primary api_key 'key0', got %q", primary.APIKey)
}
if len(primary.Fallbacks) != 2 {
t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks))
}
}
func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIKeys: []string{"key1", "key2"},
Fallbacks: []string{"claude-3"},
},
}
result := ExpandMultiKeyModels(models)
primary := result[1]
// With 2 keys, we get 1 key fallback + 1 existing fallback = 2 total
if len(primary.Fallbacks) != 2 {
t.Fatalf("expected 2 fallbacks, got %d: %v", len(primary.Fallbacks), primary.Fallbacks)
}
// Key fallbacks should come first, then existing fallbacks
if primary.Fallbacks[0] != "gpt-4__key_1" {
t.Errorf("expected first fallback 'gpt-4__key_1', got %q", primary.Fallbacks[0])
}
if primary.Fallbacks[1] != "claude-3" {
t.Errorf("expected second fallback 'claude-3', got %q", primary.Fallbacks[1])
}
}
func TestExpandMultiKeyModels_EmptyAPIKeys(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIKey: "",
APIKeys: []string{},
},
}
result := ExpandMultiKeyModels(models)
// Should keep as-is with no changes
if len(result) != 1 {
t.Fatalf("expected 1 model, got %d", len(result))
}
if result[0].ModelName != "gpt-4" {
t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName)
}
}
func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIKey: "key1",
APIKeys: []string{"key1", "key2", "key1"}, // Duplicate key1
},
}
result := ExpandMultiKeyModels(models)
// Should only create 2 models (deduplicated keys)
if len(result) != 2 {
t.Fatalf("expected 2 models (deduplicated), got %d", len(result))
}
primary := result[1]
if primary.APIKey != "key1" {
t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey)
}
if len(primary.Fallbacks) != 1 {
t.Errorf("expected 1 fallback, got %d", len(primary.Fallbacks))
}
}
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
models := []ModelConfig{
{
ModelName: "gpt-4",
Model: "openai/gpt-4o",
APIBase: "https://api.example.com",
APIKeys: []string{"key1", "key2"},
Proxy: "http://proxy:8080",
RPM: 60,
MaxTokensField: "max_completion_tokens",
RequestTimeout: 30,
ThinkingLevel: "high",
},
}
result := ExpandMultiKeyModels(models)
// Check primary entry preserves all fields
primary := result[1]
if primary.APIBase != "https://api.example.com" {
t.Errorf("expected api_base preserved, got %q", primary.APIBase)
}
if primary.Proxy != "http://proxy:8080" {
t.Errorf("expected proxy preserved, got %q", primary.Proxy)
}
if primary.RPM != 60 {
t.Errorf("expected rpm preserved, got %d", primary.RPM)
}
if primary.MaxTokensField != "max_completion_tokens" {
t.Errorf("expected max_tokens_field preserved, got %q", primary.MaxTokensField)
}
if primary.RequestTimeout != 30 {
t.Errorf("expected request_timeout preserved, got %d", primary.RequestTimeout)
}
if primary.ThinkingLevel != "high" {
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
}
// Check additional entry also preserves fields
additional := result[0]
if additional.APIBase != "https://api.example.com" {
t.Errorf("expected additional api_base preserved, got %q", additional.APIBase)
}
if additional.RPM != 60 {
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
}
}
func TestMergeAPIKeys(t *testing.T) {
tests := []struct {
name string
apiKey string
apiKeys []string
expected []string
}{
{
name: "both empty",
apiKey: "",
apiKeys: nil,
expected: nil,
},
{
name: "only apiKey",
apiKey: "key1",
apiKeys: nil,
expected: []string{"key1"},
},
{
name: "only apiKeys",
apiKey: "",
apiKeys: []string{"key1", "key2"},
expected: []string{"key1", "key2"},
},
{
name: "both with overlap",
apiKey: "key1",
apiKeys: []string{"key1", "key2", "key3"},
expected: []string{"key1", "key2", "key3"},
},
{
name: "with whitespace",
apiKey: " key1 ",
apiKeys: []string{" key2 ", " key1 "},
expected: []string{"key1", "key2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MergeAPIKeys(tt.apiKey, tt.apiKeys)
if len(result) != len(tt.expected) {
t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result))
}
for i, k := range result {
if k != tt.expected[i] {
t.Errorf("expected key[%d] = %q, got %q", i, tt.expected[i], k)
}
}
})
}
}
+9 -7
View File
@@ -117,17 +117,19 @@ func (fc *FallbackChain) Execute(
return nil, context.Canceled
}
// Check cooldown.
if !fc.cooldown.IsAvailable(candidate.Provider) {
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
// Check cooldown (per provider/model, not just provider).
// This allows multi-key failover where different keys use different model names.
cooldownKey := ModelKey(candidate.Provider, candidate.Model)
if !fc.cooldown.IsAvailable(cooldownKey) {
remaining := fc.cooldown.CooldownRemaining(cooldownKey)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Skipped: true,
Reason: FailoverRateLimit,
Error: fmt.Errorf(
"provider %s in cooldown (%s remaining)",
candidate.Provider,
"%s in cooldown (%s remaining)",
cooldownKey,
remaining.Round(time.Second),
),
})
@@ -141,7 +143,7 @@ func (fc *FallbackChain) Execute(
if err == nil {
// Success.
fc.cooldown.MarkSuccess(candidate.Provider)
fc.cooldown.MarkSuccess(cooldownKey)
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
@@ -187,7 +189,7 @@ func (fc *FallbackChain) Execute(
}
// Retriable error: mark failure and continue to next candidate.
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
fc.cooldown.MarkFailure(cooldownKey, failErr.Reason)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
+384
View File
@@ -0,0 +1,384 @@
package providers
import (
"context"
"errors"
"testing"
)
// TestMultiKeyFailover tests the complete failover flow with multiple API keys.
// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"]
// is expanded into primary + fallbacks.
func TestMultiKeyFailover(t *testing.T) {
// Simulate expanded config: primary with 2 fallbacks
// This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"]
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
}
candidates := ResolveCandidates(cfg, "zhipu")
if len(candidates) != 3 {
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
}
// Create fallback chain
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Mock run function: first call fails with 429, second succeeds
callCount := 0
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
if callCount == 1 {
// First call: simulate rate limit
return nil, errors.New("http error: status 429 - rate limit exceeded")
}
// Second call: success
return &LLMResponse{
Content: "Hello from key2!",
}, nil
}
// Execute fallback chain
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err != nil {
t.Fatalf("expected success after failover, got error: %v", err)
}
if result == nil {
t.Fatal("expected result, got nil")
}
if result.Response.Content != "Hello from key2!" {
t.Errorf("expected response from key2, got: %s", result.Response.Content)
}
if callCount != 2 {
t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount)
}
// Verify first attempt was recorded
if len(result.Attempts) != 1 {
t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts))
}
if result.Attempts[0].Reason != FailoverRateLimit {
t.Errorf(
"expected first attempt reason to be rate_limit, got: %s",
result.Attempts[0].Reason,
)
}
}
// TestMultiKeyFailoverAllFail tests when all keys hit rate limit
func TestMultiKeyFailoverAllFail(t *testing.T) {
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
}
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Mock run function: all calls fail with rate limit
callCount := 0
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
return nil, errors.New("status: 429 - too many requests")
}
// Execute fallback chain
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err == nil {
t.Fatal("expected error when all keys fail, got nil")
}
if result != nil {
t.Errorf("expected nil result on failure, got: %v", result)
}
if callCount != 3 {
t.Errorf("expected 3 calls (all fail), got %d", callCount)
}
// Verify error type
var exhausted *FallbackExhaustedError
if !errors.As(err, &exhausted) {
t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err)
}
if len(exhausted.Attempts) != 3 {
t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts))
}
}
// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped
func TestMultiKeyFailoverCooldown(t *testing.T) {
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1"},
}
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Put the first model in cooldown (using ModelKey now, not just provider)
cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model)
cooldown.MarkFailure(cooldownKey, FailoverRateLimit)
// Verify it's not available
if cooldown.IsAvailable(cooldownKey) {
t.Fatal("expected first model to be in cooldown")
}
// Mock run function: only second should be called
callCount := 0
calledProviders := []string{}
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
calledProviders = append(calledProviders, provider+"/"+model)
return &LLMResponse{Content: "success"}, nil
}
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
// First provider should have been skipped
if callCount != 1 {
t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount)
}
// Should have called the second provider/model
if len(calledProviders) != 1 ||
calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model {
t.Errorf("expected second model to be called, got: %v", calledProviders)
}
// Verify first attempt was recorded as skipped
if len(result.Attempts) != 1 {
t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts))
}
if !result.Attempts[0].Skipped {
t.Error("expected first attempt to be marked as skipped")
}
}
// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable
func TestMultiKeyFailoverWithFormatError(t *testing.T) {
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1"},
}
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Mock run function: first call fails with format error (bad request)
callCount := 0
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
return nil, errors.New("invalid request format: tool_use.id missing")
}
// Execute fallback chain
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err == nil {
t.Fatal("expected error for format failure, got nil")
}
// Format errors should NOT trigger failover (non-retriable)
// So we should only have 1 call
if callCount != 1 {
t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount)
}
// Verify the error is a FailoverError with format reason
var failoverErr *FailoverError
if !errors.As(err, &failoverErr) {
t.Errorf("expected FailoverError, got: %T - %v", err, err)
}
if failoverErr.Reason != FailoverFormat {
t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason)
}
_ = result // result should be nil
}
// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback.
// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"]
// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax
func TestMultiKeyWithModelFallback(t *testing.T) {
// Simulate expanded config from:
// { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] }
// After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"]
// Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax"
// In this test, we use the full format to avoid needing a lookup function.
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"},
}
candidates := ResolveCandidates(cfg, "zhipu")
// Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax)
if len(candidates) != 3 {
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
}
// Verify candidate order
if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" {
t.Errorf(
"expected first candidate to be zhipu/glm-4.7, got: %s/%s",
candidates[0].Provider,
candidates[0].Model,
)
}
if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" {
t.Errorf(
"expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s",
candidates[1].Provider,
candidates[1].Model,
)
}
if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" {
t.Errorf(
"expected third candidate to be minimax/minimax, got: %s/%s",
candidates[2].Provider,
candidates[2].Model,
)
}
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Mock run function: first two fail, third succeeds (model fallback)
callCount := 0
calledModels := []string{}
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
calledModels = append(calledModels, provider+"/"+model)
switch callCount {
case 1:
// k1: rate limit
return nil, errors.New("status: 429 - rate limit")
case 2:
// k2: also rate limit (all zhipu keys exhausted)
return nil, errors.New("status: 429 - rate limit")
case 3:
// minimax: success
return &LLMResponse{Content: "success from minimax"}, nil
default:
return nil, errors.New("unexpected call")
}
}
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err != nil {
t.Fatalf("expected success after failover to model fallback, got error: %v", err)
}
if callCount != 3 {
t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount)
}
if result.Response.Content != "success from minimax" {
t.Errorf("expected response from minimax, got: %s", result.Response.Content)
}
// Verify call order
if len(calledModels) != 3 {
t.Fatalf("expected 3 called models, got %d", len(calledModels))
}
if calledModels[0] != "zhipu/glm-4.7" {
t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0])
}
if calledModels[1] != "zhipu/glm-4.7__key_1" {
t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1])
}
if calledModels[2] != "minimax/minimax" {
t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2])
}
// Verify 2 failed attempts recorded
if len(result.Attempts) != 2 {
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
}
// Both should be rate limit
for i, attempt := range result.Attempts {
if attempt.Reason != FailoverRateLimit {
t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason)
}
}
}
// TestMultiKeyFailoverMixedErrors tests failover with different error types
func TestMultiKeyFailoverMixedErrors(t *testing.T) {
cfg := ModelConfig{
Primary: "glm-4.7",
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
}
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
chain := NewFallbackChain(cooldown)
// Mock run function: different errors for each key
callCount := 0
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
callCount++
switch callCount {
case 1:
// First: rate limit (retriable)
return nil, errors.New("status: 429 - rate limit")
case 2:
// Second: timeout (retriable)
return nil, errors.New("context deadline exceeded")
case 3:
// Third: success
return &LLMResponse{Content: "success from key3"}, nil
default:
return nil, errors.New("unexpected call")
}
}
result, err := chain.Execute(context.Background(), candidates, mockRun)
if err != nil {
t.Fatalf("expected success after 2 failovers, got error: %v", err)
}
if callCount != 3 {
t.Errorf("expected 3 calls, got %d", callCount)
}
// Verify both failed attempts were recorded
if len(result.Attempts) != 2 {
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
}
// First should be rate limit
if result.Attempts[0].Reason != FailoverRateLimit {
t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason)
}
// Second should be timeout
if result.Attempts[1].Reason != FailoverTimeout {
t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason)
}
}
+8 -7
View File
@@ -157,8 +157,8 @@ func TestFallback_CooldownSkip(t *testing.T) {
ct, _ := newTestTracker(now)
fc := NewFallbackChain(ct)
// Put openai in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
// Put openai/gpt-4 in cooldown (using ModelKey now)
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -195,9 +195,9 @@ func TestFallback_AllInCooldown(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
// Put all providers in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("anthropic", FailoverBilling)
// Put all models in cooldown (using ModelKey now)
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
ct.MarkFailure(ModelKey("anthropic", "claude"), FailoverBilling)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -273,12 +273,13 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
modelKey := ModelKey("openai", "gpt-4")
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
ct.MarkFailure(modelKey, FailoverRateLimit) // simulate failure tracked elsewhere
}
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
}
@@ -287,7 +288,7 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ct.IsAvailable("openai") {
if !ct.IsAvailable(modelKey) {
t.Error("success should reset cooldown")
}
}