From 6e7149509a3a4a25661604a6def08a3f532b0b87 Mon Sep 17 00:00:00 2001 From: Leandro Barbosa Date: Fri, 13 Feb 2026 12:12:12 -0300 Subject: [PATCH] feat: add model fallback chain with error classification Add 2-layer fallback system (text + image) with automatic candidate resolution. Includes error classifier (~40 patterns), per-provider cooldown (exponential backoff), and model reference parsing. - FailoverError/FailoverReason types for structured error handling - ErrorClassifier with rate_limit, billing, auth, timeout patterns - FallbackChain with cooldown management and candidate rotation - ModelRef parser for provider/model string format - 128 tests, 95%+ coverage --- pkg/providers/cooldown.go | 207 +++++++++++ pkg/providers/cooldown_test.go | 269 ++++++++++++++ pkg/providers/error_classifier.go | 253 +++++++++++++ pkg/providers/error_classifier_test.go | 337 ++++++++++++++++++ pkg/providers/fallback.go | 283 +++++++++++++++ pkg/providers/fallback_test.go | 473 +++++++++++++++++++++++++ pkg/providers/model_ref.go | 64 ++++ pkg/providers/model_ref_test.go | 125 +++++++ pkg/providers/types.go | 48 ++- 9 files changed, 2058 insertions(+), 1 deletion(-) create mode 100644 pkg/providers/cooldown.go create mode 100644 pkg/providers/cooldown_test.go create mode 100644 pkg/providers/error_classifier.go create mode 100644 pkg/providers/error_classifier_test.go create mode 100644 pkg/providers/fallback.go create mode 100644 pkg/providers/fallback_test.go create mode 100644 pkg/providers/model_ref.go create mode 100644 pkg/providers/model_ref_test.go diff --git a/pkg/providers/cooldown.go b/pkg/providers/cooldown.go new file mode 100644 index 000000000..6811297f0 --- /dev/null +++ b/pkg/providers/cooldown.go @@ -0,0 +1,207 @@ +package providers + +import ( + "math" + "sync" + "time" +) + +const ( + defaultFailureWindow = 24 * time.Hour +) + +// CooldownTracker manages per-provider cooldown state for the fallback chain. +// Thread-safe via sync.RWMutex. In-memory only (resets on restart). +type CooldownTracker struct { + mu sync.RWMutex + entries map[string]*cooldownEntry + failureWindow time.Duration + nowFunc func() time.Time // for testing +} + +type cooldownEntry struct { + ErrorCount int + FailureCounts map[FailoverReason]int + CooldownEnd time.Time // standard cooldown expiry + DisabledUntil time.Time // billing-specific disable expiry + DisabledReason FailoverReason // reason for disable (billing) + LastFailure time.Time +} + +// NewCooldownTracker creates a tracker with default 24h failure window. +func NewCooldownTracker() *CooldownTracker { + return &CooldownTracker{ + entries: make(map[string]*cooldownEntry), + failureWindow: defaultFailureWindow, + nowFunc: time.Now, + } +} + +// MarkFailure records a failure for a provider and sets appropriate cooldown. +// Resets error counts if last failure was more than failureWindow ago. +func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) { + ct.mu.Lock() + defer ct.mu.Unlock() + + now := ct.nowFunc() + entry := ct.getOrCreate(provider) + + // 24h failure window reset: if no failure in failureWindow, reset counters. + if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow { + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + } + + entry.ErrorCount++ + entry.FailureCounts[reason]++ + entry.LastFailure = now + + if reason == FailoverBilling { + billingCount := entry.FailureCounts[FailoverBilling] + entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount)) + entry.DisabledReason = FailoverBilling + } else { + entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount)) + } +} + +// MarkSuccess resets all counters and cooldowns for a provider. +func (ct *CooldownTracker) MarkSuccess(provider string) { + ct.mu.Lock() + defer ct.mu.Unlock() + + entry := ct.entries[provider] + if entry == nil { + return + } + + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + entry.CooldownEnd = time.Time{} + entry.DisabledUntil = time.Time{} + entry.DisabledReason = "" +} + +// IsAvailable returns true if the provider is not in cooldown or disabled. +func (ct *CooldownTracker) IsAvailable(provider string) bool { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return true + } + + now := ct.nowFunc() + + // Billing disable takes precedence (longer cooldown). + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + return false + } + + // Standard cooldown. + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + return false + } + + return true +} + +// CooldownRemaining returns how long until the provider becomes available. +// Returns 0 if already available. +func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + + now := ct.nowFunc() + var remaining time.Duration + + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + d := entry.DisabledUntil.Sub(now) + if d > remaining { + remaining = d + } + } + + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + d := entry.CooldownEnd.Sub(now) + if d > remaining { + remaining = d + } + } + + return remaining +} + +// ErrorCount returns the current error count for a provider. +func (ct *CooldownTracker) ErrorCount(provider string) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.ErrorCount +} + +// FailureCount returns the failure count for a specific reason. +func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.FailureCounts[reason] +} + +func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry { + entry := ct.entries[provider] + if entry == nil { + entry = &cooldownEntry{ + FailureCounts: make(map[FailoverReason]int), + } + ct.entries[provider] = entry + } + return entry +} + +// calculateStandardCooldown computes standard exponential backoff. +// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3)) +// +// 1 error → 1 min +// 2 errors → 5 min +// 3 errors → 25 min +// 4+ errors → 1 hour (cap) +func calculateStandardCooldown(errorCount int) time.Duration { + n := max(1, errorCount) + exp := min(n-1, 3) + ms := 60_000 * int(math.Pow(5, float64(exp))) + ms = min(3_600_000, ms) // cap at 1 hour + return time.Duration(ms) * time.Millisecond +} + +// calculateBillingCooldown computes billing-specific exponential backoff. +// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10)) +// +// 1 error → 5 hours +// 2 errors → 10 hours +// 3 errors → 20 hours +// 4+ errors → 24 hours (cap) +func calculateBillingCooldown(billingErrorCount int) time.Duration { + const baseMs = 5 * 60 * 60 * 1000 // 5 hours + const maxMs = 24 * 60 * 60 * 1000 // 24 hours + + n := max(1, billingErrorCount) + exp := min(n-1, 10) + raw := float64(baseMs) * math.Pow(2, float64(exp)) + ms := int(math.Min(float64(maxMs), raw)) + return time.Duration(ms) * time.Millisecond +} diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go new file mode 100644 index 000000000..e51ff40e5 --- /dev/null +++ b/pkg/providers/cooldown_test.go @@ -0,0 +1,269 @@ +package providers + +import ( + "sync" + "testing" + "time" +) + +func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) { + current := now + ct := NewCooldownTracker() + ct.nowFunc = func() time.Time { return current } + return ct, ¤t +} + +func TestCooldown_InitiallyAvailable(t *testing.T) { + ct := NewCooldownTracker() + if !ct.IsAvailable("openai") { + t.Error("new provider should be available") + } + if ct.ErrorCount("openai") != 0 { + t.Error("new provider should have 0 errors") + } +} + +func TestCooldown_StandardEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st error → 1 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown after 1st error") + } + + // Advance 61 seconds → available + *current = now.Add(61 * time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 min cooldown") + } + + // 2nd error → 5 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + *current = now.Add(61*time.Second + 4*time.Minute) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown (5 min) after 2nd error") + } + *current = now.Add(61*time.Second + 6*time.Minute) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5 min cooldown") + } +} + +func TestCooldown_StandardCap(t *testing.T) { + // Verify formula: 1m, 5m, 25m, 1h, 1h, 1h... + expected := []time.Duration{ + 1 * time.Minute, + 5 * time.Minute, + 25 * time.Minute, + 1 * time.Hour, + 1 * time.Hour, + } + + for i, want := range expected { + got := calculateStandardCooldown(i + 1) + if got != want { + t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_BillingEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st billing error → 5h cooldown + ct.MarkFailure("openai", FailoverBilling) + if ct.IsAvailable("openai") { + t.Error("should be disabled after billing error") + } + + // Advance 4h → still disabled + *current = now.Add(4 * time.Hour) + if ct.IsAvailable("openai") { + t.Error("should still be disabled (5h cooldown)") + } + + // Advance 5h + 1s → available + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5h billing cooldown") + } +} + +func TestCooldown_BillingCap(t *testing.T) { + expected := []time.Duration{ + 5 * time.Hour, + 10 * time.Hour, + 20 * time.Hour, + 24 * time.Hour, + 24 * time.Hour, + } + + for i, want := range expected { + got := calculateBillingCooldown(i + 1) + if got != want { + t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_SuccessReset(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + if ct.ErrorCount("openai") != 2 { + t.Errorf("error count = %d, want 2", ct.ErrorCount("openai")) + } + + ct.MarkSuccess("openai") + if ct.ErrorCount("openai") != 0 { + t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai")) + } + if !ct.IsAvailable("openai") { + t.Error("should be available after success") + } + if ct.FailureCount("openai", FailoverRateLimit) != 0 { + t.Error("failure counts should be reset after success") + } + if ct.FailureCount("openai", FailoverBilling) != 0 { + t.Error("billing failure count should be reset after success") + } +} + +func TestCooldown_FailureWindowReset(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 4 errors → 1h cooldown + for i := 0; i < 4; i++ { + ct.MarkFailure("openai", FailoverRateLimit) + *current = current.Add(2 * time.Second) // small advance between errors + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("error count = %d, want 4", ct.ErrorCount("openai")) + } + + // Advance 25 hours (past 24h failure window) + *current = now.Add(25 * time.Hour) + + // Next error should reset counters first, then increment to 1 + ct.MarkFailure("openai", FailoverRateLimit) + if ct.ErrorCount("openai") != 1 { + t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai")) + } +} + +func TestCooldown_PerReasonTracking(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + ct.MarkFailure("openai", FailoverAuth) + + if ct.FailureCount("openai", FailoverRateLimit) != 2 { + t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit)) + } + if ct.FailureCount("openai", FailoverBilling) != 1 { + t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling)) + } + if ct.FailureCount("openai", FailoverAuth) != 1 { + t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth)) + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai")) + } +} + +func TestCooldown_BillingTakesPrecedence(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // Standard cooldown (1 min) + billing disable (5h) + ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown + ct.MarkFailure("openai", FailoverBilling) // 5h disable + + // After 2 min: standard cooldown expired but billing still active + *current = now.Add(2 * time.Minute) + if ct.IsAvailable("openai") { + t.Error("billing disable should take precedence over standard cooldown") + } + + // After 5h + 1s: both expired + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after all cooldowns expire") + } +} + +func TestCooldown_CooldownRemaining(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // No failures → 0 remaining + if ct.CooldownRemaining("openai") != 0 { + t.Error("expected 0 remaining for new provider") + } + + ct.MarkFailure("openai", FailoverRateLimit) + + *current = now.Add(30 * time.Second) + remaining := ct.CooldownRemaining("openai") + if remaining <= 0 || remaining > 1*time.Minute { + t.Errorf("remaining = %v, expected ~30s", remaining) + } +} + +func TestCooldown_SuccessOnUnknownProvider(t *testing.T) { + ct := NewCooldownTracker() + // Should not panic + ct.MarkSuccess("nonexistent") + if !ct.IsAvailable("nonexistent") { + t.Error("nonexistent provider should be available") + } +} + +func TestCooldown_ConcurrentAccess(t *testing.T) { + ct := NewCooldownTracker() + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(3) + go func() { + defer wg.Done() + ct.MarkFailure("openai", FailoverRateLimit) + }() + go func() { + defer wg.Done() + ct.IsAvailable("openai") + }() + go func() { + defer wg.Done() + ct.MarkSuccess("openai") + }() + } + + wg.Wait() + // If we got here without panic, concurrent access is safe +} + +func TestCooldown_MultipleProviders(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + if ct.IsAvailable("openai") { + t.Error("openai should be in cooldown") + } + if ct.IsAvailable("anthropic") { + t.Error("anthropic should be in cooldown") + } + // groq was never touched + if !ct.IsAvailable("groq") { + t.Error("groq should be available") + } +} diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go new file mode 100644 index 000000000..a0f003006 --- /dev/null +++ b/pkg/providers/error_classifier.go @@ -0,0 +1,253 @@ +package providers + +import ( + "context" + "regexp" + "strings" +) + +// errorPattern defines a single pattern (string or regex) for error classification. +type errorPattern struct { + substring string + regex *regexp.Regexp +} + +func substr(s string) errorPattern { return errorPattern{substring: s} } +func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} } + +// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns). +var ( + rateLimitPatterns = []errorPattern{ + rxp(`rate[_ ]limit`), + substr("too many requests"), + substr("429"), + substr("exceeded your current quota"), + rxp(`exceeded.*quota`), + rxp(`resource has been exhausted`), + rxp(`resource.*exhausted`), + substr("resource_exhausted"), + substr("quota exceeded"), + substr("usage limit"), + } + + overloadedPatterns = []errorPattern{ + rxp(`overloaded_error`), + rxp(`"type"\s*:\s*"overloaded_error"`), + substr("overloaded"), + } + + timeoutPatterns = []errorPattern{ + substr("timeout"), + substr("timed out"), + substr("deadline exceeded"), + substr("context deadline exceeded"), + } + + billingPatterns = []errorPattern{ + rxp(`\b402\b`), + substr("payment required"), + substr("insufficient credits"), + substr("credit balance"), + substr("plans & billing"), + substr("insufficient balance"), + } + + authPatterns = []errorPattern{ + rxp(`invalid[_ ]?api[_ ]?key`), + substr("incorrect api key"), + substr("invalid token"), + substr("authentication"), + substr("re-authenticate"), + substr("oauth token refresh failed"), + substr("unauthorized"), + substr("forbidden"), + substr("access denied"), + substr("expired"), + substr("token has expired"), + rxp(`\b401\b`), + rxp(`\b403\b`), + substr("no credentials found"), + substr("no api key found"), + } + + formatPatterns = []errorPattern{ + substr("string should match pattern"), + substr("tool_use.id"), + substr("tool_use_id"), + substr("messages.1.content.1.tool_use.id"), + substr("invalid request format"), + } + + imageDimensionPatterns = []errorPattern{ + rxp(`image dimensions exceed max`), + } + + imageSizePatterns = []errorPattern{ + rxp(`image exceeds.*mb`), + } + + // Transient HTTP status codes that map to timeout (server-side failures). + transientStatusCodes = map[int]bool{ + 500: true, 502: true, 503: true, + 521: true, 522: true, 523: true, 524: true, + 529: true, + } +) + +// ClassifyError classifies an error into a FailoverError with reason. +// Returns nil if the error is not classifiable (unknown errors should not trigger fallback). +func ClassifyError(err error, provider, model string) *FailoverError { + if err == nil { + return nil + } + + // Context cancellation: user abort, never fallback. + if err == context.Canceled { + return nil + } + + // Context deadline exceeded: treat as timeout, always fallback. + if err == context.DeadlineExceeded { + return &FailoverError{ + Reason: FailoverTimeout, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + msg := strings.ToLower(err.Error()) + + // Image dimension/size errors: non-retriable, non-fallback. + if IsImageDimensionError(msg) || IsImageSizeError(msg) { + return &FailoverError{ + Reason: FailoverFormat, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + // Try HTTP status code extraction first. + if status := extractHTTPStatus(msg); status > 0 { + if reason := classifyByStatus(status); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Status: status, + Wrapped: err, + } + } + } + + // Message pattern matching (priority order from OpenClaw). + if reason := classifyByMessage(msg); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + return nil +} + +// classifyByStatus maps HTTP status codes to FailoverReason. +func classifyByStatus(status int) FailoverReason { + switch { + case status == 401 || status == 403: + return FailoverAuth + case status == 402: + return FailoverBilling + case status == 408: + return FailoverTimeout + case status == 429: + return FailoverRateLimit + case status == 400: + return FailoverFormat + case transientStatusCodes[status]: + return FailoverTimeout + } + return "" +} + +// classifyByMessage matches error messages against patterns. +// Priority order matters (from OpenClaw classifyFailoverReason). +func classifyByMessage(msg string) FailoverReason { + if matchesAny(msg, rateLimitPatterns) { + return FailoverRateLimit + } + if matchesAny(msg, overloadedPatterns) { + return FailoverRateLimit // Overloaded treated as rate_limit + } + if matchesAny(msg, billingPatterns) { + return FailoverBilling + } + if matchesAny(msg, timeoutPatterns) { + return FailoverTimeout + } + if matchesAny(msg, authPatterns) { + return FailoverAuth + } + if matchesAny(msg, formatPatterns) { + return FailoverFormat + } + return "" +} + +// extractHTTPStatus extracts an HTTP status code from an error message. +// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429". +func extractHTTPStatus(msg string) int { + // Common patterns in Go HTTP error messages + patterns := []*regexp.Regexp{ + regexp.MustCompile(`status[:\s]+(\d{3})`), + regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`), + } + + for _, p := range patterns { + if m := p.FindStringSubmatch(msg); len(m) > 1 { + return parseDigits(m[1]) + } + } + + return 0 +} + +// IsImageDimensionError returns true if the message indicates an image dimension error. +func IsImageDimensionError(msg string) bool { + return matchesAny(msg, imageDimensionPatterns) +} + +// IsImageSizeError returns true if the message indicates an image file size error. +func IsImageSizeError(msg string) bool { + return matchesAny(msg, imageSizePatterns) +} + +// matchesAny checks if msg matches any of the patterns. +func matchesAny(msg string, patterns []errorPattern) bool { + for _, p := range patterns { + if p.regex != nil { + if p.regex.MatchString(msg) { + return true + } + } else if p.substring != "" { + if strings.Contains(msg, p.substring) { + return true + } + } + } + return false +} + +// parseDigits converts a string of digits to an int. +func parseDigits(s string) int { + n := 0 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n +} diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go new file mode 100644 index 000000000..865aea57a --- /dev/null +++ b/pkg/providers/error_classifier_test.go @@ -0,0 +1,337 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestClassifyError_Nil(t *testing.T) { + result := ClassifyError(nil, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for nil error, got %+v", result) + } +} + +func TestClassifyError_ContextCanceled(t *testing.T) { + result := ClassifyError(context.Canceled, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for context.Canceled (user abort), got %+v", result) + } +} + +func TestClassifyError_ContextDeadlineExceeded(t *testing.T) { + result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil for deadline exceeded") + } + if result.Reason != FailoverTimeout { + t.Errorf("reason = %q, want timeout", result.Reason) + } +} + +func TestClassifyError_StatusCodes(t *testing.T) { + tests := []struct { + status int + reason FailoverReason + }{ + {401, FailoverAuth}, + {403, FailoverAuth}, + {402, FailoverBilling}, + {408, FailoverTimeout}, + {429, FailoverRateLimit}, + {400, FailoverFormat}, + {500, FailoverTimeout}, + {502, FailoverTimeout}, + {503, FailoverTimeout}, + {521, FailoverTimeout}, + {522, FailoverTimeout}, + {523, FailoverTimeout}, + {524, FailoverTimeout}, + {529, FailoverTimeout}, + } + + for _, tt := range tests { + err := fmt.Errorf("API error: status: %d something went wrong", tt.status) + result := ClassifyError(err, "test", "model") + if result == nil { + t.Errorf("status %d: expected non-nil", tt.status) + continue + } + if result.Reason != tt.reason { + t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason) + } + } +} + +func TestClassifyError_RateLimitPatterns(t *testing.T) { + patterns := []string{ + "rate limit exceeded", + "rate_limit reached", + "too many requests", + "exceeded your current quota", + "resource has been exhausted", + "resource_exhausted", + "quota exceeded", + "usage limit reached", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_OverloadedPatterns(t *testing.T) { + patterns := []string{ + "overloaded_error", + `{"type": "overloaded_error"}`, + "server is overloaded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + // Overloaded is treated as rate_limit + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_BillingPatterns(t *testing.T) { + patterns := []string{ + "payment required", + "insufficient credits", + "credit balance too low", + "plans & billing page", + "insufficient balance", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverBilling { + t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason) + } + } +} + +func TestClassifyError_TimeoutPatterns(t *testing.T) { + patterns := []string{ + "request timeout", + "connection timed out", + "deadline exceeded", + "context deadline exceeded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + +func TestClassifyError_AuthPatterns(t *testing.T) { + patterns := []string{ + "invalid api key", + "invalid_api_key", + "incorrect api key", + "invalid token", + "authentication failed", + "re-authenticate", + "oauth token refresh failed", + "unauthorized access", + "forbidden", + "access denied", + "expired", + "token has expired", + "no credentials found", + "no api key found", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverAuth { + t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason) + } + } +} + +func TestClassifyError_FormatPatterns(t *testing.T) { + patterns := []string{ + "string should match pattern", + "tool_use.id is required", + "invalid tool_use_id", + "messages.1.content.1.tool_use.id must be valid", + "invalid request format", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverFormat { + t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason) + } + } +} + +func TestClassifyError_ImageDimensionError(t *testing.T) { + err := errors.New("image dimensions exceed max allowed 2048x2048") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image dimension error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } + if result.IsRetriable() { + t.Error("image dimension error should not be retriable") + } +} + +func TestClassifyError_ImageSizeError(t *testing.T) { + err := errors.New("image exceeds 20 mb limit") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image size error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } +} + +func TestClassifyError_UnknownError(t *testing.T) { + err := errors.New("some completely random error") + result := ClassifyError(err, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for unknown error, got %+v", result) + } +} + +func TestClassifyError_ProviderModelPropagation(t *testing.T) { + err := errors.New("rate limit exceeded") + result := ClassifyError(err, "my-provider", "my-model") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Provider != "my-provider" { + t.Errorf("provider = %q, want my-provider", result.Provider) + } + if result.Model != "my-model" { + t.Errorf("model = %q, want my-model", result.Model) + } +} + +func TestFailoverError_IsRetriable(t *testing.T) { + tests := []struct { + reason FailoverReason + retriable bool + }{ + {FailoverAuth, true}, + {FailoverRateLimit, true}, + {FailoverBilling, true}, + {FailoverTimeout, true}, + {FailoverOverloaded, true}, + {FailoverFormat, false}, + {FailoverUnknown, true}, + } + + for _, tt := range tests { + fe := &FailoverError{Reason: tt.reason} + if fe.IsRetriable() != tt.retriable { + t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable) + } + } +} + +func TestFailoverError_ErrorString(t *testing.T) { + fe := &FailoverError{ + Reason: FailoverRateLimit, + Provider: "openai", + Model: "gpt-4", + Status: 429, + Wrapped: errors.New("too many requests"), + } + s := fe.Error() + if s == "" { + t.Error("expected non-empty error string") + } +} + +func TestFailoverError_Unwrap(t *testing.T) { + inner := errors.New("inner error") + fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner} + if fe.Unwrap() != inner { + t.Error("Unwrap should return wrapped error") + } +} + +func TestExtractHTTPStatus(t *testing.T) { + tests := []struct { + msg string + want int + }{ + {"status: 429 rate limited", 429}, + {"status 401 unauthorized", 401}, + {"HTTP/1.1 502 Bad Gateway", 502}, + {"no status code here", 0}, + {"random number 12345", 0}, + } + + for _, tt := range tests { + got := extractHTTPStatus(tt.msg) + if got != tt.want { + t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want) + } + } +} + +func TestIsImageDimensionError(t *testing.T) { + if !IsImageDimensionError("image dimensions exceed max 4096x4096") { + t.Error("should match image dimensions exceed max") + } + if IsImageDimensionError("normal error message") { + t.Error("should not match normal error") + } +} + +func TestIsImageSizeError(t *testing.T) { + if !IsImageSizeError("image exceeds 20 mb") { + t.Error("should match image exceeds mb") + } + if IsImageSizeError("normal error message") { + t.Error("should not match normal error") + } +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go new file mode 100644 index 000000000..9b07f9153 --- /dev/null +++ b/pkg/providers/fallback.go @@ -0,0 +1,283 @@ +package providers + +import ( + "context" + "fmt" + "strings" + "time" +) + +// FallbackChain orchestrates model fallback across multiple candidates. +type FallbackChain struct { + cooldown *CooldownTracker +} + +// FallbackCandidate represents one model/provider to try. +type FallbackCandidate struct { + Provider string + Model string +} + +// FallbackResult contains the successful response and metadata about all attempts. +type FallbackResult struct { + Response *LLMResponse + Provider string + Model string + Attempts []FallbackAttempt +} + +// FallbackAttempt records one attempt in the fallback chain. +type FallbackAttempt struct { + Provider string + Model string + Error error + Reason FailoverReason + Duration time.Duration + Skipped bool // true if skipped due to cooldown +} + +// NewFallbackChain creates a new fallback chain with the given cooldown tracker. +func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { + return &FallbackChain{cooldown: cooldown} +} + +// ResolveCandidates parses model config into a deduplicated candidate list. +func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate { + seen := make(map[string]bool) + var candidates []FallbackCandidate + + addCandidate := func(raw string) { + ref := ParseModelRef(raw, defaultProvider) + if ref == nil { + return + } + key := ModelKey(ref.Provider, ref.Model) + if seen[key] { + return + } + seen[key] = true + candidates = append(candidates, FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + }) + } + + // Primary first. + addCandidate(cfg.Primary) + + // Then fallbacks. + for _, fb := range cfg.Fallbacks { + addCandidate(fb) + } + + return candidates +} + +// Execute runs the fallback chain for text/chat requests. +// It tries each candidate in order, respecting cooldowns and error classification. +// +// Behavior: +// - Candidates in cooldown are skipped (logged as skipped attempt). +// - context.Canceled aborts immediately (user abort, no fallback). +// - Non-retriable errors (format) abort immediately. +// - Retriable errors trigger fallback to next candidate. +// - Success marks provider as good (resets cooldown). +// - If all fail, returns aggregate error with all attempts. +func (fc *FallbackChain) Execute( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + // Check context before each attempt. + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + // Check cooldown. + if !fc.cooldown.IsAvailable(candidate.Provider) { + remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)), + }) + continue + } + + // Execute the run function. + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + // Success. + fc.cooldown.MarkSuccess(candidate.Provider) + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + // Context cancellation: abort immediately, no fallback. + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Classify the error. + failErr := ClassifyError(err, candidate.Provider, candidate.Model) + + if failErr == nil { + // Unclassifiable error: do not fallback, return immediately. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w", + candidate.Provider, candidate.Model, err) + } + + // Non-retriable error: abort immediately. + if !failErr.IsRetriable() { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + return nil, failErr + } + + // Retriable error: mark failure and continue to next candidate. + fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + + // If this was the last candidate, return aggregate error. + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + // All candidates were skipped (all in cooldown). + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// ExecuteImage runs the fallback chain for image/vision requests. +// Simpler than Execute: no cooldown checks (image endpoints have different rate limits). +// Image dimension/size errors abort immediately (non-retriable). +func (fc *FallbackChain) ExecuteImage( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("image fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Image dimension/size errors are non-retriable. + errMsg := strings.ToLower(err.Error()) + if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Reason: FailoverFormat, + Duration: elapsed, + }) + return nil, &FailoverError{ + Reason: FailoverFormat, + Provider: candidate.Provider, + Model: candidate.Model, + Wrapped: err, + } + } + + // Any other error: record and try next. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// FallbackExhaustedError indicates all fallback candidates were tried and failed. +type FallbackExhaustedError struct { + Attempts []FallbackAttempt +} + +func (e *FallbackExhaustedError) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts))) + for i, a := range e.Attempts { + if a.Skipped { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model)) + } else { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)", + i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond))) + } + } + return sb.String() +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go new file mode 100644 index 000000000..ea81e0d48 --- /dev/null +++ b/pkg/providers/fallback_test.go @@ -0,0 +1,473 @@ +package providers + +import ( + "context" + "errors" + "testing" + "time" +) + +func makeCandidate(provider, model string) FallbackCandidate { + return FallbackCandidate{Provider: provider, Model: model} +} + +func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return &LLMResponse{Content: content, FinishReason: "stop"}, nil + } +} + +func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, err + } +} + +func TestFallback_SingleCandidate_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("hello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "hello" { + t.Errorf("content = %q, want hello", result.Response.Content) + } + if result.Provider != "openai" || result.Model != "gpt-4" { + t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model) + } +} + +func TestFallback_SecondCandidateSuccess(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude-opus"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + if result.Response.Content != "from claude" { + t.Errorf("content = %q, want 'from claude'", result.Response.Content) + } + if len(result.Attempts) != 1 { + t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts)) + } +} + +func TestFallback_AllFail(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + makeCandidate("groq", "llama"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, errors.New("rate limit exceeded") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err) + } + if len(exhausted.Attempts) != 3 { + t.Errorf("attempts = %d, want 3", len(exhausted.Attempts)) + } +} + +func TestFallback_ContextCanceled(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + ctx, cancel := context.WithCancel(context.Background()) + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + cancel() // cancel context + return nil, context.Canceled + } + t.Error("should not reach second candidate after cancel") + return nil, nil + } + + _, err := fc.Execute(ctx, candidates, run) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestFallback_NonRetriableError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("string should match pattern") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for non-retriable") + } + var fe *FailoverError + if !errors.As(err, &fe) { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", fe.Reason) + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt) + } +} + +func TestFallback_CooldownSkip(t *testing.T) { + now := time.Now() + ct, _ := newTestTracker(now) + fc := NewFallbackChain(ct) + + // Put openai in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + if provider == "openai" { + t.Error("should not call openai (in cooldown)") + } + return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + // Should have 1 skipped attempt + skipped := 0 + for _, a := range result.Attempts { + if a.Skipped { + skipped++ + } + } + if skipped != 1 { + t.Errorf("skipped = %d, want 1", skipped) + } +} + +func TestFallback_AllInCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + // Put all providers in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + _, err := fc.Execute(context.Background(), candidates, + func(ctx context.Context, provider, model string) (*LLMResponse, error) { + t.Error("should not call any provider (all in cooldown)") + return nil, nil + }) + + if err == nil { + t.Fatal("expected error when all in cooldown") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Fatalf("expected FallbackExhaustedError, got %T", err) + } +} + +func TestFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.Execute(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +func TestFallback_EmptyFallbacks(t *testing.T) { + // Single primary, no fallbacks: should work like direct call + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("ok")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "ok" { + t.Error("expected success with single candidate") + } +} + +func TestFallback_UnclassifiedError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("completely unknown internal error") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for unclassified error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt) + } +} + +func TestFallback_SuccessResetsCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + } + return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ct.IsAvailable("openai") { + t.Error("success should reset cooldown") + } +} + +// --- Image Fallback Tests --- + +func TestImageFallback_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")} + result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "image result" { + t.Error("expected image result") + } +} + +func TestImageFallback_DimensionError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image dimensions exceed max 4096x4096") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image dimension error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt) + } +} + +func TestImageFallback_SizeError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image exceeds 20 mb") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image size error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt) + } +} + +func TestImageFallback_RetryOnOtherErrors(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude-sonnet"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil + } + + result, err := fc.ExecuteImage(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } +} + +func TestImageFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.ExecuteImage(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +// --- ResolveCandidates Tests --- + +func TestResolveCandidates_Simple(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 3 { + t.Fatalf("candidates = %d, want 3", len(candidates)) + } + + if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" { + t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model) + } + if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" { + t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model) + } + if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" { + t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model) + } +} + +func TestResolveCandidates_Deduplication(t *testing.T) { + cfg := ModelConfig{ + Primary: "openai/gpt-4", + Fallbacks: []string{"openai/gpt-4", "anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "default") + if len(candidates) != 2 { + t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates)) + } +} + +func TestResolveCandidates_EmptyFallbacks(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: nil, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestResolveCandidates_EmptyPrimary(t *testing.T) { + cfg := ModelConfig{ + Primary: "", + Fallbacks: []string{"anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestFallbackExhaustedError_Message(t *testing.T) { + e := &FallbackExhaustedError{ + Attempts: []FallbackAttempt{ + {Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond}, + {Provider: "anthropic", Model: "claude", Skipped: true}, + }, + } + msg := e.Error() + if msg == "" { + t.Error("expected non-empty error message") + } +} diff --git a/pkg/providers/model_ref.go b/pkg/providers/model_ref.go new file mode 100644 index 000000000..0d1b02d16 --- /dev/null +++ b/pkg/providers/model_ref.go @@ -0,0 +1,64 @@ +package providers + +import "strings" + +// ModelRef represents a parsed model reference with provider and model name. +type ModelRef struct { + Provider string + Model string +} + +// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}. +// If no slash present, uses defaultProvider. +// Returns nil for empty input. +func ParseModelRef(raw string, defaultProvider string) *ModelRef { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + if idx := strings.Index(raw, "/"); idx > 0 { + provider := NormalizeProvider(raw[:idx]) + model := strings.TrimSpace(raw[idx+1:]) + if model == "" { + return nil + } + return &ModelRef{Provider: provider, Model: model} + } + + return &ModelRef{ + Provider: NormalizeProvider(defaultProvider), + Model: raw, + } +} + +// NormalizeProvider normalizes provider identifiers to canonical form. +func NormalizeProvider(provider string) string { + p := strings.ToLower(strings.TrimSpace(provider)) + + switch p { + case "z.ai", "z-ai": + return "zai" + case "opencode-zen": + return "opencode" + case "qwen": + return "qwen-portal" + case "kimi-code": + return "kimi-coding" + case "gpt": + return "openai" + case "claude": + return "anthropic" + case "glm": + return "zhipu" + case "google": + return "gemini" + } + + return p +} + +// ModelKey returns a canonical "provider/model" key for deduplication. +func ModelKey(provider, model string) string { + return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model)) +} diff --git a/pkg/providers/model_ref_test.go b/pkg/providers/model_ref_test.go new file mode 100644 index 000000000..6dd25167f --- /dev/null +++ b/pkg/providers/model_ref_test.go @@ -0,0 +1,125 @@ +package providers + +import "testing" + +func TestParseModelRef_WithSlash(t *testing.T) { + ref := ParseModelRef("anthropic/claude-opus", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestParseModelRef_WithoutSlash(t *testing.T) { + ref := ParseModelRef("gpt-4", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai", ref.Provider) + } + if ref.Model != "gpt-4" { + t.Errorf("model = %q, want gpt-4", ref.Model) + } +} + +func TestParseModelRef_Empty(t *testing.T) { + ref := ParseModelRef("", "openai") + if ref != nil { + t.Errorf("expected nil for empty string, got %+v", ref) + } +} + +func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) { + ref := ParseModelRef("openai/", "default") + if ref != nil { + t.Errorf("expected nil for empty model, got %+v", ref) + } +} + +func TestParseModelRef_WhitespaceHandling(t *testing.T) { + ref := ParseModelRef(" anthropic / claude-opus ", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestNormalizeProvider(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"OpenAI", "openai"}, + {"ANTHROPIC", "anthropic"}, + {"z.ai", "zai"}, + {"z-ai", "zai"}, + {"Z.AI", "zai"}, + {"opencode-zen", "opencode"}, + {"qwen", "qwen-portal"}, + {"kimi-code", "kimi-coding"}, + {"gpt", "openai"}, + {"claude", "anthropic"}, + {"glm", "zhipu"}, + {"google", "gemini"}, + {"groq", "groq"}, + {"", ""}, + } + + for _, tt := range tests { + got := NormalizeProvider(tt.input) + if got != tt.want { + t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestModelKey(t *testing.T) { + tests := []struct { + provider string + model string + want string + }{ + {"openai", "gpt-4", "openai/gpt-4"}, + {"Anthropic", "Claude-Opus", "anthropic/claude-opus"}, + {"claude", "sonnet", "anthropic/sonnet"}, + {"z.ai", "Model-X", "zai/model-x"}, + } + + for _, tt := range tests { + got := ModelKey(tt.provider, tt.model) + if got != tt.want { + t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want) + } + } +} + +func TestParseModelRef_ProviderNormalization(t *testing.T) { + ref := ParseModelRef("Z.AI/model-x", "default") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "zai" { + t.Errorf("provider = %q, want zai", ref.Provider) + } +} + +func TestParseModelRef_DefaultProviderNormalization(t *testing.T) { + ref := ParseModelRef("gpt-4o", "GPT") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider) + } +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 88b62e975..aa30a1a46 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,6 +1,9 @@ package providers -import "context" +import ( + "context" + "fmt" +) type ToolCall struct { ID string `json:"id"` @@ -40,6 +43,49 @@ type LLMProvider interface { GetDefaultModel() string } +// FailoverReason classifies why an LLM request failed for fallback decisions. +type FailoverReason string + +const ( + FailoverAuth FailoverReason = "auth" + FailoverRateLimit FailoverReason = "rate_limit" + FailoverBilling FailoverReason = "billing" + FailoverTimeout FailoverReason = "timeout" + FailoverFormat FailoverReason = "format" + FailoverOverloaded FailoverReason = "overloaded" + FailoverUnknown FailoverReason = "unknown" +) + +// FailoverError wraps an LLM provider error with classification metadata. +type FailoverError struct { + Reason FailoverReason + Provider string + Model string + Status int + Wrapped error +} + +func (e *FailoverError) Error() string { + return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v", + e.Reason, e.Provider, e.Model, e.Status, e.Wrapped) +} + +func (e *FailoverError) Unwrap() error { + return e.Wrapped +} + +// IsRetriable returns true if this error should trigger fallback to next candidate. +// Non-retriable: Format errors (bad request structure, image dimension/size). +func (e *FailoverError) IsRetriable() bool { + return e.Reason != FailoverFormat +} + +// ModelConfig holds primary model and fallback list. +type ModelConfig struct { + Primary string + Fallbacks []string +} + type ToolDefinition struct { Type string `json:"type"` Function ToolFunctionDefinition `json:"function"`