mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
e73d9d959e
* feat(config): support multiple API keys for failover
Add api_keys field to ModelConfig to support multiple API keys with
automatic failover. When multiple keys are configured, they are expanded
into separate model entries with fallbacks set up for key-level failover.
Example config:
{
"model_name": "glm-4.7",
"model": "zhipu/glm-4.7",
"api_keys": ["key1", "key2", "key3"]
}
Expands internally to:
- glm-4.7 (key1) -> fallbacks: [glm-4.7__key_1, glm-4.7__key_2]
- glm-4.7__key_1 (key2)
- glm-4.7__key_2 (key3)
Backward compatible: single api_key still works as before.
* fix(providers): change cooldown tracking from provider to ModelKey
This enables proper key-switching when multiple API keys share the same
provider. Previously, when one key failed, all keys were blocked because
cooldown was tracked per-provider.
Now each (provider, model) combination has independent cooldown, allowing
fallback to alternate keys when one is rate limited.
Includes TestMultiKeyWithModelFallback and related failover tests.
385 lines
11 KiB
Go
385 lines
11 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
)
|
|
|
|
// TestMultiKeyFailover tests the complete failover flow with multiple API keys.
|
|
// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"]
|
|
// is expanded into primary + fallbacks.
|
|
func TestMultiKeyFailover(t *testing.T) {
|
|
// Simulate expanded config: primary with 2 fallbacks
|
|
// This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"]
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
if len(candidates) != 3 {
|
|
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
|
|
}
|
|
|
|
// Create fallback chain
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Mock run function: first call fails with 429, second succeeds
|
|
callCount := 0
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
if callCount == 1 {
|
|
// First call: simulate rate limit
|
|
return nil, errors.New("http error: status 429 - rate limit exceeded")
|
|
}
|
|
// Second call: success
|
|
return &LLMResponse{
|
|
Content: "Hello from key2!",
|
|
}, nil
|
|
}
|
|
|
|
// Execute fallback chain
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
if err != nil {
|
|
t.Fatalf("expected success after failover, got error: %v", err)
|
|
}
|
|
|
|
if result == nil {
|
|
t.Fatal("expected result, got nil")
|
|
}
|
|
|
|
if result.Response.Content != "Hello from key2!" {
|
|
t.Errorf("expected response from key2, got: %s", result.Response.Content)
|
|
}
|
|
|
|
if callCount != 2 {
|
|
t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount)
|
|
}
|
|
|
|
// Verify first attempt was recorded
|
|
if len(result.Attempts) != 1 {
|
|
t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts))
|
|
}
|
|
|
|
if result.Attempts[0].Reason != FailoverRateLimit {
|
|
t.Errorf(
|
|
"expected first attempt reason to be rate_limit, got: %s",
|
|
result.Attempts[0].Reason,
|
|
)
|
|
}
|
|
}
|
|
|
|
// TestMultiKeyFailoverAllFail tests when all keys hit rate limit
|
|
func TestMultiKeyFailoverAllFail(t *testing.T) {
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Mock run function: all calls fail with rate limit
|
|
callCount := 0
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
return nil, errors.New("status: 429 - too many requests")
|
|
}
|
|
|
|
// Execute fallback chain
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error when all keys fail, got nil")
|
|
}
|
|
|
|
if result != nil {
|
|
t.Errorf("expected nil result on failure, got: %v", result)
|
|
}
|
|
|
|
if callCount != 3 {
|
|
t.Errorf("expected 3 calls (all fail), got %d", callCount)
|
|
}
|
|
|
|
// Verify error type
|
|
var exhausted *FallbackExhaustedError
|
|
if !errors.As(err, &exhausted) {
|
|
t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err)
|
|
}
|
|
|
|
if len(exhausted.Attempts) != 3 {
|
|
t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts))
|
|
}
|
|
}
|
|
|
|
// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped
|
|
func TestMultiKeyFailoverCooldown(t *testing.T) {
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Put the first model in cooldown (using ModelKey now, not just provider)
|
|
cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model)
|
|
cooldown.MarkFailure(cooldownKey, FailoverRateLimit)
|
|
|
|
// Verify it's not available
|
|
if cooldown.IsAvailable(cooldownKey) {
|
|
t.Fatal("expected first model to be in cooldown")
|
|
}
|
|
|
|
// Mock run function: only second should be called
|
|
callCount := 0
|
|
calledProviders := []string{}
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
calledProviders = append(calledProviders, provider+"/"+model)
|
|
return &LLMResponse{Content: "success"}, nil
|
|
}
|
|
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
if err != nil {
|
|
t.Fatalf("expected success, got error: %v", err)
|
|
}
|
|
|
|
// First provider should have been skipped
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount)
|
|
}
|
|
|
|
// Should have called the second provider/model
|
|
if len(calledProviders) != 1 ||
|
|
calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model {
|
|
t.Errorf("expected second model to be called, got: %v", calledProviders)
|
|
}
|
|
|
|
// Verify first attempt was recorded as skipped
|
|
if len(result.Attempts) != 1 {
|
|
t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts))
|
|
}
|
|
|
|
if !result.Attempts[0].Skipped {
|
|
t.Error("expected first attempt to be marked as skipped")
|
|
}
|
|
}
|
|
|
|
// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable
|
|
func TestMultiKeyFailoverWithFormatError(t *testing.T) {
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Mock run function: first call fails with format error (bad request)
|
|
callCount := 0
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
return nil, errors.New("invalid request format: tool_use.id missing")
|
|
}
|
|
|
|
// Execute fallback chain
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error for format failure, got nil")
|
|
}
|
|
|
|
// Format errors should NOT trigger failover (non-retriable)
|
|
// So we should only have 1 call
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount)
|
|
}
|
|
|
|
// Verify the error is a FailoverError with format reason
|
|
var failoverErr *FailoverError
|
|
if !errors.As(err, &failoverErr) {
|
|
t.Errorf("expected FailoverError, got: %T - %v", err, err)
|
|
}
|
|
|
|
if failoverErr.Reason != FailoverFormat {
|
|
t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason)
|
|
}
|
|
|
|
_ = result // result should be nil
|
|
}
|
|
|
|
// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback.
|
|
// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"]
|
|
// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax
|
|
func TestMultiKeyWithModelFallback(t *testing.T) {
|
|
// Simulate expanded config from:
|
|
// { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] }
|
|
// After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"]
|
|
// Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax"
|
|
// In this test, we use the full format to avoid needing a lookup function.
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
// Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax)
|
|
if len(candidates) != 3 {
|
|
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
|
|
}
|
|
|
|
// Verify candidate order
|
|
if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" {
|
|
t.Errorf(
|
|
"expected first candidate to be zhipu/glm-4.7, got: %s/%s",
|
|
candidates[0].Provider,
|
|
candidates[0].Model,
|
|
)
|
|
}
|
|
if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" {
|
|
t.Errorf(
|
|
"expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s",
|
|
candidates[1].Provider,
|
|
candidates[1].Model,
|
|
)
|
|
}
|
|
if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" {
|
|
t.Errorf(
|
|
"expected third candidate to be minimax/minimax, got: %s/%s",
|
|
candidates[2].Provider,
|
|
candidates[2].Model,
|
|
)
|
|
}
|
|
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Mock run function: first two fail, third succeeds (model fallback)
|
|
callCount := 0
|
|
calledModels := []string{}
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
calledModels = append(calledModels, provider+"/"+model)
|
|
|
|
switch callCount {
|
|
case 1:
|
|
// k1: rate limit
|
|
return nil, errors.New("status: 429 - rate limit")
|
|
case 2:
|
|
// k2: also rate limit (all zhipu keys exhausted)
|
|
return nil, errors.New("status: 429 - rate limit")
|
|
case 3:
|
|
// minimax: success
|
|
return &LLMResponse{Content: "success from minimax"}, nil
|
|
default:
|
|
return nil, errors.New("unexpected call")
|
|
}
|
|
}
|
|
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
if err != nil {
|
|
t.Fatalf("expected success after failover to model fallback, got error: %v", err)
|
|
}
|
|
|
|
if callCount != 3 {
|
|
t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount)
|
|
}
|
|
|
|
if result.Response.Content != "success from minimax" {
|
|
t.Errorf("expected response from minimax, got: %s", result.Response.Content)
|
|
}
|
|
|
|
// Verify call order
|
|
if len(calledModels) != 3 {
|
|
t.Fatalf("expected 3 called models, got %d", len(calledModels))
|
|
}
|
|
if calledModels[0] != "zhipu/glm-4.7" {
|
|
t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0])
|
|
}
|
|
if calledModels[1] != "zhipu/glm-4.7__key_1" {
|
|
t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1])
|
|
}
|
|
if calledModels[2] != "minimax/minimax" {
|
|
t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2])
|
|
}
|
|
|
|
// Verify 2 failed attempts recorded
|
|
if len(result.Attempts) != 2 {
|
|
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
|
|
}
|
|
|
|
// Both should be rate limit
|
|
for i, attempt := range result.Attempts {
|
|
if attempt.Reason != FailoverRateLimit {
|
|
t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMultiKeyFailoverMixedErrors tests failover with different error types
|
|
func TestMultiKeyFailoverMixedErrors(t *testing.T) {
|
|
cfg := ModelConfig{
|
|
Primary: "glm-4.7",
|
|
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
|
}
|
|
|
|
candidates := ResolveCandidates(cfg, "zhipu")
|
|
|
|
cooldown := NewCooldownTracker()
|
|
chain := NewFallbackChain(cooldown)
|
|
|
|
// Mock run function: different errors for each key
|
|
callCount := 0
|
|
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
|
callCount++
|
|
switch callCount {
|
|
case 1:
|
|
// First: rate limit (retriable)
|
|
return nil, errors.New("status: 429 - rate limit")
|
|
case 2:
|
|
// Second: timeout (retriable)
|
|
return nil, errors.New("context deadline exceeded")
|
|
case 3:
|
|
// Third: success
|
|
return &LLMResponse{Content: "success from key3"}, nil
|
|
default:
|
|
return nil, errors.New("unexpected call")
|
|
}
|
|
}
|
|
|
|
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
|
if err != nil {
|
|
t.Fatalf("expected success after 2 failovers, got error: %v", err)
|
|
}
|
|
|
|
if callCount != 3 {
|
|
t.Errorf("expected 3 calls, got %d", callCount)
|
|
}
|
|
|
|
// Verify both failed attempts were recorded
|
|
if len(result.Attempts) != 2 {
|
|
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
|
|
}
|
|
|
|
// First should be rate limit
|
|
if result.Attempts[0].Reason != FailoverRateLimit {
|
|
t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason)
|
|
}
|
|
|
|
// Second should be timeout
|
|
if result.Attempts[1].Reason != FailoverTimeout {
|
|
t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason)
|
|
}
|
|
}
|