feat: add model fallback chain with error classification

Add 2-layer fallback system (text + image) with automatic candidate
resolution. Includes error classifier (~40 patterns), per-provider
cooldown (exponential backoff), and model reference parsing.

- FailoverError/FailoverReason types for structured error handling
- ErrorClassifier with rate_limit, billing, auth, timeout patterns
- FallbackChain with cooldown management and candidate rotation
- ModelRef parser for provider/model string format
- 128 tests, 95%+ coverage
This commit is contained in:
Leandro Barbosa
2026-02-13 12:12:12 -03:00
parent d83fb6e081
commit 6e7149509a
9 changed files with 2058 additions and 1 deletions
+207
View File
@@ -0,0 +1,207 @@
package providers
import (
"math"
"sync"
"time"
)
const (
// defaultFailureWindow is the failure window NewCooldownTracker installs.
// MarkFailure resets a provider's counters when its previous failure is
// older than this.
defaultFailureWindow = 24 * time.Hour
)
// CooldownTracker manages per-provider cooldown state for the fallback chain.
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
//
// Build one with NewCooldownTracker: the zero value has a nil entries map and
// a nil nowFunc, both of which MarkFailure uses.
type CooldownTracker struct {
// mu is held by every method before entries is read or written.
mu sync.RWMutex
// entries maps a provider name to its failure counters and deadlines.
entries map[string]*cooldownEntry
// failureWindow is the gap since the last failure after which MarkFailure
// clears the counters before recording the new failure.
failureWindow time.Duration
nowFunc func() time.Time // for testing
}
// cooldownEntry is the state CooldownTracker keeps for one provider.
type cooldownEntry struct {
// ErrorCount counts failures of any reason since the last reset; it drives
// the standard cooldown length.
ErrorCount int
// FailureCounts is the same tally split by reason; the billing count drives
// the billing disable length.
FailureCounts map[FailoverReason]int
CooldownEnd time.Time // standard cooldown expiry
DisabledUntil time.Time // billing-specific disable expiry
DisabledReason FailoverReason // reason for disable (billing)
// LastFailure is the time of the most recent MarkFailure call. MarkSuccess
// leaves it untouched.
LastFailure time.Time
}
// NewCooldownTracker returns an empty tracker configured with the default
// 24h failure window and time.Now as its clock.
func NewCooldownTracker() *CooldownTracker {
	ct := new(CooldownTracker)
	ct.entries = map[string]*cooldownEntry{}
	ct.failureWindow = defaultFailureWindow
	ct.nowFunc = time.Now
	return ct
}
// MarkFailure registers one failure of the given reason against provider and
// arms the matching cooldown.
//
// If the provider's previous failure is older than the tracker's failure
// window, its counters are cleared first so the backoff restarts from the
// first step. A billing failure sets the billing disable deadline, sized by
// the number of billing failures; any other reason sets the standard cooldown,
// sized by the total error count (which billing failures also increment).
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
	ct.mu.Lock()
	defer ct.mu.Unlock()

	entry := ct.getOrCreate(provider)
	now := ct.nowFunc()

	// A provider that has never failed has nothing to expire.
	stale := !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow
	if stale {
		entry.ErrorCount = 0
		entry.FailureCounts = make(map[FailoverReason]int)
	}

	entry.LastFailure = now
	entry.ErrorCount++
	entry.FailureCounts[reason]++

	switch reason {
	case FailoverBilling:
		entry.DisabledReason = FailoverBilling
		entry.DisabledUntil = now.Add(calculateBillingCooldown(entry.FailureCounts[FailoverBilling]))
	default:
		entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
	}
}
// MarkSuccess clears the error counters, the standard cooldown and the billing
// disable for provider. The time of the last failure is kept. Providers the
// tracker has never seen are ignored.
func (ct *CooldownTracker) MarkSuccess(provider string) {
	ct.mu.Lock()
	defer ct.mu.Unlock()

	if entry := ct.entries[provider]; entry != nil {
		// Everything returns to its zero value except LastFailure.
		*entry = cooldownEntry{
			FailureCounts: make(map[FailoverReason]int),
			LastFailure:   entry.LastFailure,
		}
	}
}
// IsAvailable reports whether provider may be used right now: it is neither
// billing-disabled nor inside a standard cooldown. Providers the tracker has
// never seen are available.
func (ct *CooldownTracker) IsAvailable(provider string) bool {
	ct.mu.RLock()
	defer ct.mu.RUnlock()

	entry := ct.entries[provider]
	if entry == nil {
		return true
	}

	now := ct.nowFunc()
	billingDisabled := !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil)
	coolingDown := !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd)
	return !(billingDisabled || coolingDown)
}
// CooldownRemaining returns the time left until provider is available again:
// the longer of the pending billing disable and the pending standard cooldown.
// It returns 0 when neither is pending or the provider is unknown.
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
	ct.mu.RLock()
	defer ct.mu.RUnlock()

	entry := ct.entries[provider]
	if entry == nil {
		return 0
	}

	now := ct.nowFunc()
	var longest time.Duration
	for _, deadline := range [...]time.Time{entry.DisabledUntil, entry.CooldownEnd} {
		// Skip deadlines that were never set or have already passed.
		if deadline.IsZero() || !now.Before(deadline) {
			continue
		}
		longest = max(longest, deadline.Sub(now))
	}
	return longest
}
// ErrorCount returns the total number of failures currently recorded for
// provider, or 0 for a provider the tracker has never seen.
func (ct *CooldownTracker) ErrorCount(provider string) int {
	ct.mu.RLock()
	defer ct.mu.RUnlock()

	if entry := ct.entries[provider]; entry != nil {
		return entry.ErrorCount
	}
	return 0
}
// FailureCount returns how many failures with the given reason are currently
// recorded for provider, or 0 for a provider the tracker has never seen.
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
	ct.mu.RLock()
	defer ct.mu.RUnlock()

	if entry := ct.entries[provider]; entry != nil {
		return entry.FailureCounts[reason]
	}
	return 0
}
// getOrCreate returns the entry for provider, inserting an empty one on first
// use. It takes no lock itself; MarkFailure calls it while holding ct.mu.
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
	if existing := ct.entries[provider]; existing != nil {
		return existing
	}
	fresh := &cooldownEntry{FailureCounts: map[FailoverReason]int{}}
	ct.entries[provider] = fresh
	return fresh
}
// calculateStandardCooldown returns the standard backoff for the given number
// of consecutive errors: one minute multiplied by five for each error after
// the first, capped at one hour.
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
//
//	n <= 1 → 1 min
//	n == 2 → 5 min
//	n == 3 → 25 min
//	n >= 4 → 1 hour (cap)
func calculateStandardCooldown(errorCount int) time.Duration {
	// Counts below one are treated as the first error.
	escalations := min(max(errorCount, 1)-1, 3)
	factor := time.Duration(math.Pow(5, float64(escalations)))
	return min(time.Minute*factor, time.Hour)
}
// calculateBillingCooldown returns the billing backoff for the given number of
// billing errors: five hours doubled for each error after the first, capped at
// twenty-four hours.
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
//
//	n <= 1 → 5 hours
//	n == 2 → 10 hours
//	n == 3 → 20 hours
//	n >= 4 → 24 hours (cap)
func calculateBillingCooldown(billingErrorCount int) time.Duration {
	const (
		base    = 5 * time.Hour
		ceiling = 24 * time.Hour
	)
	// Counts below one are treated as the first error.
	doublings := min(max(billingErrorCount, 1)-1, 10)
	scaled := float64(base) * math.Pow(2, float64(doublings))
	return time.Duration(math.Min(scaled, float64(ceiling)))
}
+269
View File
@@ -0,0 +1,269 @@
package providers
import (
"sync"
"testing"
"time"
)
// newTestTracker returns a tracker whose clock reads the returned *time.Time,
// initialised to now. Tests move time by assigning through the pointer.
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
	clock := new(time.Time)
	*clock = now

	tracker := NewCooldownTracker()
	tracker.nowFunc = func() time.Time { return *clock }
	return tracker, clock
}
// TestCooldown_InitiallyAvailable checks that a provider with no history is
// available and has no recorded errors.
func TestCooldown_InitiallyAvailable(t *testing.T) {
	const provider = "openai"
	tracker := NewCooldownTracker()

	if available := tracker.IsAvailable(provider); !available {
		t.Error("new provider should be available")
	}
	if count := tracker.ErrorCount(provider); count != 0 {
		t.Error("new provider should have 0 errors")
	}
}
// TestCooldown_StandardEscalation walks the fake clock through the first two
// standard cooldown steps (1 min, then 5 min).
func TestCooldown_StandardEscalation(t *testing.T) {
	const provider = "openai"
	start := time.Now()
	tracker, clock := newTestTracker(start)
	// at moves the fake clock to start+offset.
	at := func(offset time.Duration) { *clock = start.Add(offset) }

	// First failure arms the 1 minute cooldown.
	tracker.MarkFailure(provider, FailoverRateLimit)
	if tracker.IsAvailable(provider) {
		t.Error("should be in cooldown after 1st error")
	}

	at(61 * time.Second)
	if !tracker.IsAvailable(provider) {
		t.Error("should be available after 1 min cooldown")
	}

	// Second failure, recorded at +61s, arms the 5 minute cooldown.
	tracker.MarkFailure(provider, FailoverRateLimit)

	at(61*time.Second + 4*time.Minute)
	if tracker.IsAvailable(provider) {
		t.Error("should be in cooldown (5 min) after 2nd error")
	}

	at(61*time.Second + 6*time.Minute)
	if !tracker.IsAvailable(provider) {
		t.Error("should be available after 5 min cooldown")
	}
}
// TestCooldown_StandardCap pins the standard backoff table: 1m, 5m, 25m, then
// 1h for every further error.
func TestCooldown_StandardCap(t *testing.T) {
	steps := []struct {
		errors int
		want   time.Duration
	}{
		{1, 1 * time.Minute},
		{2, 5 * time.Minute},
		{3, 25 * time.Minute},
		{4, 1 * time.Hour},
		{5, 1 * time.Hour},
	}

	for _, step := range steps {
		if got := calculateStandardCooldown(step.errors); got != step.want {
			t.Errorf("calculateStandardCooldown(%d) = %v, want %v", step.errors, got, step.want)
		}
	}
}
// TestCooldown_BillingEscalation checks that one billing failure disables the
// provider for five hours.
func TestCooldown_BillingEscalation(t *testing.T) {
	const provider = "openai"
	start := time.Now()
	tracker, clock := newTestTracker(start)

	tracker.MarkFailure(provider, FailoverBilling)
	if tracker.IsAvailable(provider) {
		t.Error("should be disabled after billing error")
	}

	// Four hours in: still inside the 5h disable.
	*clock = start.Add(4 * time.Hour)
	if tracker.IsAvailable(provider) {
		t.Error("should still be disabled (5h cooldown)")
	}

	// Just past five hours: disable has expired.
	*clock = start.Add(5*time.Hour + 1*time.Second)
	if !tracker.IsAvailable(provider) {
		t.Error("should be available after 5h billing cooldown")
	}
}
// TestCooldown_BillingCap pins the billing backoff table: 5h, 10h, 20h, then
// 24h for every further billing error.
func TestCooldown_BillingCap(t *testing.T) {
	steps := []struct {
		errors int
		want   time.Duration
	}{
		{1, 5 * time.Hour},
		{2, 10 * time.Hour},
		{3, 20 * time.Hour},
		{4, 24 * time.Hour},
		{5, 24 * time.Hour},
	}

	for _, step := range steps {
		if got := calculateBillingCooldown(step.errors); got != step.want {
			t.Errorf("calculateBillingCooldown(%d) = %v, want %v", step.errors, got, step.want)
		}
	}
}
// TestCooldown_SuccessReset checks that MarkSuccess clears the total count,
// the per-reason counts and both cooldowns.
func TestCooldown_SuccessReset(t *testing.T) {
	const provider = "openai"
	tracker := NewCooldownTracker()

	tracker.MarkFailure(provider, FailoverRateLimit)
	tracker.MarkFailure(provider, FailoverBilling)
	if got := tracker.ErrorCount(provider); got != 2 {
		t.Errorf("error count = %d, want 2", got)
	}

	tracker.MarkSuccess(provider)

	if got := tracker.ErrorCount(provider); got != 0 {
		t.Errorf("error count after success = %d, want 0", got)
	}
	if !tracker.IsAvailable(provider) {
		t.Error("should be available after success")
	}
	if got := tracker.FailureCount(provider, FailoverRateLimit); got != 0 {
		t.Error("failure counts should be reset after success")
	}
	if got := tracker.FailureCount(provider, FailoverBilling); got != 0 {
		t.Error("billing failure count should be reset after success")
	}
}
// TestCooldown_FailureWindowReset checks that a failure arriving more than
// 24h after the previous one restarts the error count at 1.
func TestCooldown_FailureWindowReset(t *testing.T) {
	const provider = "openai"
	start := time.Now()
	tracker, clock := newTestTracker(start)

	// Four failures two seconds apart.
	for n := 1; n <= 4; n++ {
		tracker.MarkFailure(provider, FailoverRateLimit)
		*clock = clock.Add(2 * time.Second)
	}
	if got := tracker.ErrorCount(provider); got != 4 {
		t.Errorf("error count = %d, want 4", got)
	}

	// Jump past the 24h failure window, then fail once more.
	*clock = start.Add(25 * time.Hour)
	tracker.MarkFailure(provider, FailoverRateLimit)

	if got := tracker.ErrorCount(provider); got != 1 {
		t.Errorf("error count after window reset = %d, want 1 (reset + 1)", got)
	}
}
// TestCooldown_PerReasonTracking checks that failures are tallied per reason
// as well as in total.
func TestCooldown_PerReasonTracking(t *testing.T) {
	const provider = "openai"
	tracker := NewCooldownTracker()

	for _, reason := range []FailoverReason{
		FailoverRateLimit, FailoverRateLimit, FailoverBilling, FailoverAuth,
	} {
		tracker.MarkFailure(provider, reason)
	}

	if got := tracker.FailureCount(provider, FailoverRateLimit); got != 2 {
		t.Errorf("rate_limit count = %d, want 2", got)
	}
	if got := tracker.FailureCount(provider, FailoverBilling); got != 1 {
		t.Errorf("billing count = %d, want 1", got)
	}
	if got := tracker.FailureCount(provider, FailoverAuth); got != 1 {
		t.Errorf("auth count = %d, want 1", got)
	}
	if got := tracker.ErrorCount(provider); got != 4 {
		t.Errorf("total error count = %d, want 4", got)
	}
}
// TestCooldown_BillingTakesPrecedence checks that a billing disable keeps the
// provider unavailable after the shorter standard cooldown has expired.
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
	const provider = "openai"
	start := time.Now()
	tracker, clock := newTestTracker(start)

	tracker.MarkFailure(provider, FailoverRateLimit) // arms the 1 min cooldown
	tracker.MarkFailure(provider, FailoverBilling)   // arms the 5h disable

	// Two minutes in, only the billing disable is still pending.
	*clock = start.Add(2 * time.Minute)
	if tracker.IsAvailable(provider) {
		t.Error("billing disable should take precedence over standard cooldown")
	}

	// Just past five hours, nothing is pending.
	*clock = start.Add(5*time.Hour + 1*time.Second)
	if !tracker.IsAvailable(provider) {
		t.Error("should be available after all cooldowns expire")
	}
}
// TestCooldown_CooldownRemaining checks CooldownRemaining before any failure,
// midway through the first standard cooldown, and after it has expired.
//
// The tracker runs on a fake clock, so the expected values are exact: one
// failure arms a 1 minute cooldown, leaving exactly 30s at the 30s mark.
func TestCooldown_CooldownRemaining(t *testing.T) {
	now := time.Now()
	ct, current := newTestTracker(now)

	// No failures → 0 remaining
	if ct.CooldownRemaining("openai") != 0 {
		t.Error("expected 0 remaining for new provider")
	}

	ct.MarkFailure("openai", FailoverRateLimit)

	// Halfway through the 1 minute cooldown.
	*current = now.Add(30 * time.Second)
	if remaining := ct.CooldownRemaining("openai"); remaining != 30*time.Second {
		t.Errorf("remaining = %v, expected ~30s", remaining)
	}

	// Past the cooldown → back to 0.
	*current = now.Add(61 * time.Second)
	if remaining := ct.CooldownRemaining("openai"); remaining != 0 {
		t.Errorf("remaining after expiry = %v, want 0", remaining)
	}
}
// TestCooldown_SuccessOnUnknownProvider checks that MarkSuccess on a provider
// with no entry does not panic and leaves it available.
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
	const provider = "nonexistent"
	tracker := NewCooldownTracker()

	tracker.MarkSuccess(provider) // must not panic

	if available := tracker.IsAvailable(provider); !available {
		t.Error("nonexistent provider should be available")
	}
}
// TestCooldown_ConcurrentAccess runs MarkFailure, IsAvailable and MarkSuccess
// from 300 goroutines at once. It asserts nothing itself; it exists to surface
// panics and, under -race, data races.
func TestCooldown_ConcurrentAccess(t *testing.T) {
	tracker := NewCooldownTracker()
	ops := []func(){
		func() { tracker.MarkFailure("openai", FailoverRateLimit) },
		func() { tracker.IsAvailable("openai") },
		func() { tracker.MarkSuccess("openai") },
	}

	var wg sync.WaitGroup
	for round := 0; round < 100; round++ {
		for _, op := range ops {
			wg.Add(1)
			go func(call func()) {
				defer wg.Done()
				call()
			}(op)
		}
	}
	wg.Wait()
}
// TestCooldown_MultipleProviders checks that cooldown state is kept per
// provider: failing two providers leaves a third untouched.
func TestCooldown_MultipleProviders(t *testing.T) {
	tracker := NewCooldownTracker()
	tracker.MarkFailure("openai", FailoverRateLimit)
	tracker.MarkFailure("anthropic", FailoverBilling)

	if available := tracker.IsAvailable("openai"); available {
		t.Error("openai should be in cooldown")
	}
	if available := tracker.IsAvailable("anthropic"); available {
		t.Error("anthropic should be in cooldown")
	}
	// groq never failed.
	if available := tracker.IsAvailable("groq"); !available {
		t.Error("groq should be available")
	}
}
+253
View File
@@ -0,0 +1,253 @@
package providers
import (
"context"
"regexp"
"strings"
)
// errorPattern defines a single pattern (string or regex) for error classification.
// matchesAny tries regex when it is non-nil and otherwise falls back to a
// substring test; a pattern with neither field set never matches.
type errorPattern struct {
// substring is matched with strings.Contains.
substring string
// regex, when set, is used instead of substring.
regex *regexp.Regexp
}
// substr builds a pattern that matches any message containing s.
func substr(s string) errorPattern {
	var p errorPattern
	p.substring = s
	return p
}
// rxp builds a case-insensitive regular-expression pattern from r.
// It panics (via regexp.MustCompile) if r is not a valid expression.
func rxp(r string) errorPattern {
	compiled := regexp.MustCompile("(?i)" + r)
	return errorPattern{regex: compiled}
}
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
//
// ClassifyError lower-cases the message before matching, so substr patterns
// must be written in lower case. Several entries are subsumed by a broader
// sibling in the same list (for example "token has expired" by "expired").
var (
// rateLimitPatterns → FailoverRateLimit. substr("429") matches those digits
// anywhere in the message, not only as a status code.
rateLimitPatterns = []errorPattern{
rxp(`rate[_ ]limit`),
substr("too many requests"),
substr("429"),
substr("exceeded your current quota"),
rxp(`exceeded.*quota`),
rxp(`resource has been exhausted`),
rxp(`resource.*exhausted`),
substr("resource_exhausted"),
substr("quota exceeded"),
substr("usage limit"),
}
// overloadedPatterns are reported as FailoverRateLimit by classifyByMessage.
overloadedPatterns = []errorPattern{
rxp(`overloaded_error`),
rxp(`"type"\s*:\s*"overloaded_error"`),
substr("overloaded"),
}
// timeoutPatterns → FailoverTimeout.
timeoutPatterns = []errorPattern{
substr("timeout"),
substr("timed out"),
substr("deadline exceeded"),
substr("context deadline exceeded"),
}
// billingPatterns → FailoverBilling.
billingPatterns = []errorPattern{
rxp(`\b402\b`),
substr("payment required"),
substr("insufficient credits"),
substr("credit balance"),
substr("plans & billing"),
substr("insufficient balance"),
}
// authPatterns → FailoverAuth. substr("expired") is deliberately broad and
// matches any message containing that word.
authPatterns = []errorPattern{
rxp(`invalid[_ ]?api[_ ]?key`),
substr("incorrect api key"),
substr("invalid token"),
substr("authentication"),
substr("re-authenticate"),
substr("oauth token refresh failed"),
substr("unauthorized"),
substr("forbidden"),
substr("access denied"),
substr("expired"),
substr("token has expired"),
rxp(`\b401\b`),
rxp(`\b403\b`),
substr("no credentials found"),
substr("no api key found"),
}
// formatPatterns → FailoverFormat.
formatPatterns = []errorPattern{
substr("string should match pattern"),
substr("tool_use.id"),
substr("tool_use_id"),
substr("messages.1.content.1.tool_use.id"),
substr("invalid request format"),
}
// imageDimensionPatterns backs IsImageDimensionError.
imageDimensionPatterns = []errorPattern{
rxp(`image dimensions exceed max`),
}
// imageSizePatterns backs IsImageSizeError.
imageSizePatterns = []errorPattern{
rxp(`image exceeds.*mb`),
}
// Transient HTTP status codes that map to timeout (server-side failures).
transientStatusCodes = map[int]bool{
500: true, 502: true, 503: true,
521: true, 522: true, 523: true, 524: true,
529: true,
}
)
// ClassifyError turns err into a *FailoverError carrying a reason plus the
// given provider and model, or returns nil when err should not trigger
// fallback (nil, context.Canceled, or nothing matched).
//
// Checks run in this order, first hit wins:
//  1. context.DeadlineExceeded → FailoverTimeout.
//  2. Image dimension/size message → FailoverFormat.
//  3. An HTTP status found in the message that classifyByStatus recognises;
//     only this path fills in Status.
//  4. Message patterns via classifyByMessage.
//
// NOTE(review): the context sentinels are compared with ==, so a wrapped
// context error is not caught by steps 0/1 and falls through to message
// matching.
func ClassifyError(err error, provider, model string) *FailoverError {
	// classified wraps err with the caller's provider/model.
	classified := func(reason FailoverReason, status int) *FailoverError {
		return &FailoverError{
			Reason:   reason,
			Provider: provider,
			Model:    model,
			Status:   status,
			Wrapped:  err,
		}
	}

	switch err {
	case nil:
		return nil
	case context.Canceled:
		// User abort: never fall back.
		return nil
	case context.DeadlineExceeded:
		return classified(FailoverTimeout, 0)
	}

	msg := strings.ToLower(err.Error())

	if IsImageDimensionError(msg) || IsImageSizeError(msg) {
		return classified(FailoverFormat, 0)
	}

	if status := extractHTTPStatus(msg); status > 0 {
		if reason := classifyByStatus(status); reason != "" {
			return classified(reason, status)
		}
	}

	if reason := classifyByMessage(msg); reason != "" {
		return classified(reason, 0)
	}
	return nil
}
// classifyByStatus maps an HTTP status code to a FailoverReason, returning ""
// for codes it does not recognise. Codes listed in transientStatusCodes are
// reported as timeouts.
func classifyByStatus(status int) FailoverReason {
	switch status {
	case 401, 403:
		return FailoverAuth
	case 402:
		return FailoverBilling
	case 408:
		return FailoverTimeout
	case 429:
		return FailoverRateLimit
	case 400:
		return FailoverFormat
	}
	if transientStatusCodes[status] {
		return FailoverTimeout
	}
	return ""
}
// classifyByMessage returns the reason of the first pattern group that matches
// msg, or "" when none does. Group order is significant (it follows OpenClaw's
// classifyFailoverReason): rate limit, overloaded, billing, timeout, auth,
// format.
func classifyByMessage(msg string) FailoverReason {
	ordered := []struct {
		patterns []errorPattern
		reason   FailoverReason
	}{
		{rateLimitPatterns, FailoverRateLimit},
		{overloadedPatterns, FailoverRateLimit}, // Overloaded treated as rate_limit
		{billingPatterns, FailoverBilling},
		{timeoutPatterns, FailoverTimeout},
		{authPatterns, FailoverAuth},
		{formatPatterns, FailoverFormat},
	}

	for _, group := range ordered {
		if matchesAny(msg, group.patterns) {
			return group.reason
		}
	}
	return ""
}
// httpStatusPatterns are the message shapes extractHTTPStatus recognises,
// tried in order. They are compiled once at package initialisation rather than
// on every call, and are case-insensitive because ClassifyError lower-cases
// the message before extraction. The trailing \b stops a longer digit run
// ("status: 12345") from being read as a three-digit code.
var httpStatusPatterns = []*regexp.Regexp{
	// "status: 429", "status 429", "status:429"
	regexp.MustCompile(`(?i)status[:\s]+(\d{3})\b`),
	// "HTTP 429", "HTTP/1.1 502", "http/2 503"
	regexp.MustCompile(`(?i)http(?:/\d+(?:\.\d+)?)?\s+(\d{3})\b`),
}

// extractHTTPStatus extracts an HTTP status code from an error message.
// It recognises "status: 429", "status 429", "HTTP 429" and "HTTP/1.1 429"
// in any letter case, and returns 0 when none of those forms is present.
// A bare number with no "status"/"HTTP" prefix is not treated as a status.
func extractHTTPStatus(msg string) int {
	for _, p := range httpStatusPatterns {
		m := p.FindStringSubmatch(msg)
		if len(m) < 2 {
			continue
		}
		// The capture group is exactly three ASCII digits, so this
		// conversion cannot fail.
		code := 0
		for i := 0; i < len(m[1]); i++ {
			code = code*10 + int(m[1][i]-'0')
		}
		return code
	}
	return 0
}
// IsImageDimensionError reports whether msg matches one of the image-dimension
// patterns ("image dimensions exceed max", case-insensitive).
func IsImageDimensionError(msg string) bool {
	matched := matchesAny(msg, imageDimensionPatterns)
	return matched
}
// IsImageSizeError reports whether msg matches one of the image-file-size
// patterns ("image exceeds … mb", case-insensitive).
func IsImageSizeError(msg string) bool {
	matched := matchesAny(msg, imageSizePatterns)
	return matched
}
// matchesAny reports whether at least one pattern matches msg. A pattern with
// a regex is tested with that regex only; otherwise a non-empty substring is
// tested with strings.Contains. Patterns with neither are skipped.
func matchesAny(msg string, patterns []errorPattern) bool {
	for i := range patterns {
		p := &patterns[i]
		switch {
		case p.regex != nil:
			if p.regex.MatchString(msg) {
				return true
			}
		case p.substring != "":
			if strings.Contains(msg, p.substring) {
				return true
			}
		}
	}
	return false
}
// parseDigits reads the ASCII digits in s, left to right, as a base-10
// integer. Every other byte is skipped, so "4a2" yields 42 and "" yields 0.
func parseDigits(s string) int {
	value := 0
	for i := 0; i < len(s); i++ {
		b := s[i]
		if b < '0' || b > '9' {
			continue
		}
		value = value*10 + int(b-'0')
	}
	return value
}
+337
View File
@@ -0,0 +1,337 @@
package providers
import (
"context"
"errors"
"fmt"
"testing"
)
// TestClassifyError_Nil checks that a nil error is not classified.
func TestClassifyError_Nil(t *testing.T) {
	if got := ClassifyError(nil, "openai", "gpt-4"); got != nil {
		t.Errorf("expected nil for nil error, got %+v", got)
	}
}
// TestClassifyError_ContextCanceled checks that a user abort is never
// classified, so it cannot trigger fallback.
func TestClassifyError_ContextCanceled(t *testing.T) {
	if got := ClassifyError(context.Canceled, "openai", "gpt-4"); got != nil {
		t.Errorf("expected nil for context.Canceled (user abort), got %+v", got)
	}
}
// TestClassifyError_ContextDeadlineExceeded checks that the deadline sentinel
// is classified as a timeout.
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
	got := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
	if got == nil {
		t.Fatal("expected non-nil for deadline exceeded")
	}
	if reason := got.Reason; reason != FailoverTimeout {
		t.Errorf("reason = %q, want timeout", reason)
	}
}
// TestClassifyError_StatusCodes checks the status → reason mapping for
// messages of the form "status: <code>".
func TestClassifyError_StatusCodes(t *testing.T) {
	type expectation struct {
		code int
		want FailoverReason
	}
	table := []expectation{
		{401, FailoverAuth},
		{403, FailoverAuth},
		{402, FailoverBilling},
		{408, FailoverTimeout},
		{429, FailoverRateLimit},
		{400, FailoverFormat},
		{500, FailoverTimeout},
		{502, FailoverTimeout},
		{503, FailoverTimeout},
		{521, FailoverTimeout},
		{522, FailoverTimeout},
		{523, FailoverTimeout},
		{524, FailoverTimeout},
		{529, FailoverTimeout},
	}

	for _, row := range table {
		got := ClassifyError(
			fmt.Errorf("API error: status: %d something went wrong", row.code),
			"test", "model",
		)
		switch {
		case got == nil:
			t.Errorf("status %d: expected non-nil", row.code)
		case got.Reason != row.want:
			t.Errorf("status %d: reason = %q, want %q", row.code, got.Reason, row.want)
		}
	}
}
// TestClassifyError_RateLimitPatterns checks that each sample message is
// classified as a rate limit.
func TestClassifyError_RateLimitPatterns(t *testing.T) {
	for _, msg := range []string{
		"rate limit exceeded",
		"rate_limit reached",
		"too many requests",
		"exceeded your current quota",
		"resource has been exhausted",
		"resource_exhausted",
		"quota exceeded",
		"usage limit reached",
	} {
		got := ClassifyError(errors.New(msg), "openai", "gpt-4")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverRateLimit:
			t.Errorf("pattern %q: reason = %q, want rate_limit", msg, got.Reason)
		}
	}
}
// TestClassifyError_OverloadedPatterns checks that "overloaded" messages are
// reported with the rate-limit reason.
func TestClassifyError_OverloadedPatterns(t *testing.T) {
	for _, msg := range []string{
		"overloaded_error",
		`{"type": "overloaded_error"}`,
		"server is overloaded",
	} {
		got := ClassifyError(errors.New(msg), "anthropic", "claude")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverRateLimit:
			// Overloaded shares the rate_limit reason.
			t.Errorf("pattern %q: reason = %q, want rate_limit", msg, got.Reason)
		}
	}
}
// TestClassifyError_BillingPatterns checks that each sample message is
// classified as a billing failure.
func TestClassifyError_BillingPatterns(t *testing.T) {
	for _, msg := range []string{
		"payment required",
		"insufficient credits",
		"credit balance too low",
		"plans & billing page",
		"insufficient balance",
	} {
		got := ClassifyError(errors.New(msg), "openai", "gpt-4")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverBilling:
			t.Errorf("pattern %q: reason = %q, want billing", msg, got.Reason)
		}
	}
}
// TestClassifyError_TimeoutPatterns checks that each sample message is
// classified as a timeout.
func TestClassifyError_TimeoutPatterns(t *testing.T) {
	for _, msg := range []string{
		"request timeout",
		"connection timed out",
		"deadline exceeded",
		"context deadline exceeded",
	} {
		got := ClassifyError(errors.New(msg), "openai", "gpt-4")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverTimeout:
			t.Errorf("pattern %q: reason = %q, want timeout", msg, got.Reason)
		}
	}
}
// TestClassifyError_AuthPatterns checks that each sample message is classified
// as an authentication failure.
func TestClassifyError_AuthPatterns(t *testing.T) {
	for _, msg := range []string{
		"invalid api key",
		"invalid_api_key",
		"incorrect api key",
		"invalid token",
		"authentication failed",
		"re-authenticate",
		"oauth token refresh failed",
		"unauthorized access",
		"forbidden",
		"access denied",
		"expired",
		"token has expired",
		"no credentials found",
		"no api key found",
	} {
		got := ClassifyError(errors.New(msg), "openai", "gpt-4")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverAuth:
			t.Errorf("pattern %q: reason = %q, want auth", msg, got.Reason)
		}
	}
}
// TestClassifyError_FormatPatterns checks that each sample message is
// classified as a request-format failure.
func TestClassifyError_FormatPatterns(t *testing.T) {
	for _, msg := range []string{
		"string should match pattern",
		"tool_use.id is required",
		"invalid tool_use_id",
		"messages.1.content.1.tool_use.id must be valid",
		"invalid request format",
	} {
		got := ClassifyError(errors.New(msg), "anthropic", "claude")
		switch {
		case got == nil:
			t.Errorf("pattern %q: expected non-nil", msg)
		case got.Reason != FailoverFormat:
			t.Errorf("pattern %q: reason = %q, want format", msg, got.Reason)
		}
	}
}
// TestClassifyError_ImageDimensionError checks that an image-dimension message
// is classified as a non-retriable format failure.
func TestClassifyError_ImageDimensionError(t *testing.T) {
	cause := errors.New("image dimensions exceed max allowed 2048x2048")

	got := ClassifyError(cause, "openai", "gpt-4o")
	if got == nil {
		t.Fatal("expected non-nil for image dimension error")
	}
	if reason := got.Reason; reason != FailoverFormat {
		t.Errorf("reason = %q, want format", reason)
	}
	if retriable := got.IsRetriable(); retriable {
		t.Error("image dimension error should not be retriable")
	}
}
// TestClassifyError_ImageSizeError checks that an image-size message is
// classified as a format failure.
func TestClassifyError_ImageSizeError(t *testing.T) {
	cause := errors.New("image exceeds 20 mb limit")

	got := ClassifyError(cause, "openai", "gpt-4o")
	if got == nil {
		t.Fatal("expected non-nil for image size error")
	}
	if reason := got.Reason; reason != FailoverFormat {
		t.Errorf("reason = %q, want format", reason)
	}
}
// TestClassifyError_UnknownError checks that a message matching no pattern is
// left unclassified.
func TestClassifyError_UnknownError(t *testing.T) {
	cause := errors.New("some completely random error")
	if got := ClassifyError(cause, "openai", "gpt-4"); got != nil {
		t.Errorf("expected nil for unknown error, got %+v", got)
	}
}
// TestClassifyError_ProviderModelPropagation checks that the provider and
// model arguments are copied onto the returned FailoverError.
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
	got := ClassifyError(errors.New("rate limit exceeded"), "my-provider", "my-model")
	if got == nil {
		t.Fatal("expected non-nil")
	}
	if provider := got.Provider; provider != "my-provider" {
		t.Errorf("provider = %q, want my-provider", provider)
	}
	if model := got.Model; model != "my-model" {
		t.Errorf("model = %q, want my-model", model)
	}
}
// TestFailoverError_IsRetriable checks that every reason except format is
// retriable.
func TestFailoverError_IsRetriable(t *testing.T) {
	table := []struct {
		reason FailoverReason
		want   bool
	}{
		{FailoverAuth, true},
		{FailoverRateLimit, true},
		{FailoverBilling, true},
		{FailoverTimeout, true},
		{FailoverOverloaded, true},
		{FailoverFormat, false},
		{FailoverUnknown, true},
	}

	for _, row := range table {
		subject := &FailoverError{Reason: row.reason}
		if got := subject.IsRetriable(); got != row.want {
			t.Errorf("IsRetriable(%q) = %v, want %v", row.reason, got, row.want)
		}
	}
}
// TestFailoverError_ErrorString checks that a fully populated FailoverError
// renders a non-empty message. The message content is not asserted.
func TestFailoverError_ErrorString(t *testing.T) {
	subject := &FailoverError{
		Reason:   FailoverRateLimit,
		Provider: "openai",
		Model:    "gpt-4",
		Status:   429,
		Wrapped:  errors.New("too many requests"),
	}

	if text := subject.Error(); len(text) == 0 {
		t.Error("expected non-empty error string")
	}
}
// TestFailoverError_Unwrap checks that Unwrap returns the wrapped error value.
func TestFailoverError_Unwrap(t *testing.T) {
	cause := errors.New("inner error")
	subject := &FailoverError{Reason: FailoverTimeout, Wrapped: cause}

	if got := subject.Unwrap(); got != cause {
		t.Error("Unwrap should return wrapped error")
	}
}
// TestExtractHTTPStatus checks status extraction for the "status" and "HTTP/x"
// forms and that messages without either prefix yield 0.
func TestExtractHTTPStatus(t *testing.T) {
	type sample struct {
		input string
		code  int
	}
	samples := []sample{
		{"status: 429 rate limited", 429},
		{"status 401 unauthorized", 401},
		{"HTTP/1.1 502 Bad Gateway", 502},
		{"no status code here", 0},
		{"random number 12345", 0},
	}

	for _, s := range samples {
		if got := extractHTTPStatus(s.input); got != s.code {
			t.Errorf("extractHTTPStatus(%q) = %d, want %d", s.input, got, s.code)
		}
	}
}
// TestIsImageDimensionError checks one matching and one non-matching message.
func TestIsImageDimensionError(t *testing.T) {
	if matched := IsImageDimensionError("image dimensions exceed max 4096x4096"); !matched {
		t.Error("should match image dimensions exceed max")
	}
	if matched := IsImageDimensionError("normal error message"); matched {
		t.Error("should not match normal error")
	}
}
// TestIsImageSizeError checks one matching and one non-matching message.
func TestIsImageSizeError(t *testing.T) {
	if matched := IsImageSizeError("image exceeds 20 mb"); !matched {
		t.Error("should match image exceeds mb")
	}
	if matched := IsImageSizeError("normal error message"); matched {
		t.Error("should not match normal error")
	}
}
+283
View File
@@ -0,0 +1,283 @@
package providers
import (
"context"
"fmt"
"strings"
"time"
)
// FallbackChain orchestrates model fallback across multiple candidates.
type FallbackChain struct {
// cooldown is consulted by Execute to skip unavailable providers and is
// updated on each success or retriable failure. ExecuteImage does not use it.
cooldown *CooldownTracker
}
// FallbackCandidate represents one model/provider to try.
type FallbackCandidate struct {
// Provider is the name passed to the run callback and used as the cooldown key.
Provider string
// Model is the model name passed to the run callback.
Model string
}
// FallbackResult contains the successful response and metadata about all attempts.
type FallbackResult struct {
// Response is the value the run callback returned for the winning candidate.
Response *LLMResponse
// Provider and Model identify the candidate that succeeded.
Provider string
Model string
// Attempts lists the candidates that were skipped or failed before the
// successful one; the successful call itself is not recorded here.
Attempts []FallbackAttempt
}
// FallbackAttempt records one attempt in the fallback chain.
type FallbackAttempt struct {
Provider string
Model string
// Error is the failure for this attempt, or a synthetic "in cooldown" error
// when the candidate was skipped.
Error error
// Reason is the classified failure reason. ExecuteImage leaves it empty for
// ordinary failures because it does not classify them.
Reason FailoverReason
// Duration is how long the run callback took; it is zero for skipped candidates.
Duration time.Duration
Skipped bool // true if skipped due to cooldown
}
// NewFallbackChain returns a chain that reads and updates provider
// availability through the given cooldown tracker.
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
	chain := new(FallbackChain)
	chain.cooldown = cooldown
	return chain
}
// ResolveCandidates turns cfg into an ordered candidate list: the primary
// model first, then the fallbacks in configuration order. References that
// ParseModelRef rejects (nil result) are dropped, and a provider/model pair
// that has already been added is not added again.
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
	// Primary leads so it wins over a duplicate among the fallbacks.
	raws := make([]string, 0, 1+len(cfg.Fallbacks))
	raws = append(raws, cfg.Primary)
	raws = append(raws, cfg.Fallbacks...)

	var resolved []FallbackCandidate
	taken := map[string]bool{}
	for _, raw := range raws {
		ref := ParseModelRef(raw, defaultProvider)
		if ref == nil {
			continue
		}
		key := ModelKey(ref.Provider, ref.Model)
		if taken[key] {
			continue
		}
		taken[key] = true
		resolved = append(resolved, FallbackCandidate{Provider: ref.Provider, Model: ref.Model})
	}
	return resolved
}
// Execute runs the fallback chain for text/chat requests, trying candidates in
// order until one succeeds.
//
// Per candidate:
//   - If ctx is already cancelled, return context.Canceled.
//   - If the provider is in cooldown, record a skipped attempt and move on.
//     Skipped attempts always carry FailoverRateLimit as their reason, even
//     when the provider is billing-disabled.
//   - On success, mark the provider good and return the result together with
//     the attempts recorded so far.
//   - If ctx was cancelled during the call, return context.Canceled.
//   - If ClassifyError cannot classify the error, return it wrapped, without
//     trying further candidates.
//   - If the classified error is not retriable (format), return it.
//   - Otherwise mark the failure on the provider, record the attempt, and try
//     the next candidate.
//
// When every candidate was skipped or failed, the error is a
// *FallbackExhaustedError holding all recorded attempts. Early aborts return
// only the error; the attempt list is not exposed for them.
//
// NOTE(review): only context.Canceled stops the loop. With an expired ctx
// deadline the remaining candidates are still tried — confirm that is intended.
func (fc *FallbackChain) Execute(
	ctx context.Context,
	candidates []FallbackCandidate,
	run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
	if len(candidates) == 0 {
		return nil, fmt.Errorf("fallback: no candidates configured")
	}

	attempts := make([]FallbackAttempt, 0, len(candidates))

	for _, cand := range candidates {
		if ctx.Err() == context.Canceled {
			return nil, context.Canceled
		}

		if !fc.cooldown.IsAvailable(cand.Provider) {
			remaining := fc.cooldown.CooldownRemaining(cand.Provider)
			attempts = append(attempts, FallbackAttempt{
				Provider: cand.Provider,
				Model:    cand.Model,
				Skipped:  true,
				Reason:   FailoverRateLimit,
				Error:    fmt.Errorf("provider %s in cooldown (%s remaining)", cand.Provider, remaining.Round(time.Second)),
			})
			continue
		}

		began := time.Now()
		resp, err := run(ctx, cand.Provider, cand.Model)
		took := time.Since(began)

		if err == nil {
			fc.cooldown.MarkSuccess(cand.Provider)
			return &FallbackResult{
				Response: resp,
				Provider: cand.Provider,
				Model:    cand.Model,
				Attempts: attempts,
			}, nil
		}

		// User abort during the call: no fallback.
		if ctx.Err() == context.Canceled {
			return nil, context.Canceled
		}

		failErr := ClassifyError(err, cand.Provider, cand.Model)
		if failErr == nil {
			return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
				cand.Provider, cand.Model, err)
		}
		if !failErr.IsRetriable() {
			return nil, failErr
		}

		fc.cooldown.MarkFailure(cand.Provider, failErr.Reason)
		attempts = append(attempts, FallbackAttempt{
			Provider: cand.Provider,
			Model:    cand.Model,
			Error:    failErr,
			Reason:   failErr.Reason,
			Duration: took,
		})
	}

	// Every candidate was skipped or failed with a retriable error.
	return nil, &FallbackExhaustedError{Attempts: attempts}
}
// ExecuteImage runs the fallback chain for image/vision requests, trying
// candidates in order until one succeeds. Unlike Execute it neither consults
// nor updates the cooldown tracker, and it does not classify errors.
//
// Per candidate:
//   - If ctx is cancelled before or during the call, return context.Canceled.
//   - On success, return the result with the attempts recorded so far.
//   - If the error message is an image dimension/size error, return a
//     *FailoverError with reason FailoverFormat without trying further
//     candidates.
//   - Any other error is recorded and the next candidate is tried.
//
// When every candidate failed, the error is a *FallbackExhaustedError holding
// the recorded attempts.
func (fc *FallbackChain) ExecuteImage(
	ctx context.Context,
	candidates []FallbackCandidate,
	run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
	if len(candidates) == 0 {
		return nil, fmt.Errorf("image fallback: no candidates configured")
	}

	attempts := make([]FallbackAttempt, 0, len(candidates))

	for _, cand := range candidates {
		if ctx.Err() == context.Canceled {
			return nil, context.Canceled
		}

		began := time.Now()
		resp, err := run(ctx, cand.Provider, cand.Model)
		took := time.Since(began)

		if err == nil {
			return &FallbackResult{
				Response: resp,
				Provider: cand.Provider,
				Model:    cand.Model,
				Attempts: attempts,
			}, nil
		}

		if ctx.Err() == context.Canceled {
			return nil, context.Canceled
		}

		// Another model would reject the same image, so stop here.
		lowered := strings.ToLower(err.Error())
		if IsImageDimensionError(lowered) || IsImageSizeError(lowered) {
			return nil, &FailoverError{
				Reason:   FailoverFormat,
				Provider: cand.Provider,
				Model:    cand.Model,
				Wrapped:  err,
			}
		}

		attempts = append(attempts, FallbackAttempt{
			Provider: cand.Provider,
			Model:    cand.Model,
			Error:    err,
			Duration: took,
		})
	}

	return nil, &FallbackExhaustedError{Attempts: attempts}
}
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
type FallbackExhaustedError struct {
	// Attempts holds one record per candidate, in the order they were recorded.
	Attempts []FallbackAttempt
}

// Error renders a multi-line summary: a header with the number of attempts,
// followed by one line per attempt. Skipped attempts are labelled as cooldown
// skips; all others show the error, its reason and the duration rounded to
// the millisecond.
func (e *FallbackExhaustedError) Error() string {
	var out strings.Builder
	fmt.Fprintf(&out, "fallback: all %d candidates failed:", len(e.Attempts))
	for idx, att := range e.Attempts {
		pos := idx + 1 // attempts are numbered from 1 in the message
		if att.Skipped {
			fmt.Fprintf(&out, "\n [%d] %s/%s: skipped (cooldown)", pos, att.Provider, att.Model)
			continue
		}
		fmt.Fprintf(&out, "\n [%d] %s/%s: %v (reason=%s, %s)",
			pos, att.Provider, att.Model, att.Error, att.Reason, att.Duration.Round(time.Millisecond))
	}
	return out.String()
}
+473
View File
@@ -0,0 +1,473 @@
package providers
import (
"context"
"errors"
"testing"
"time"
)
// makeCandidate builds a FallbackCandidate for the given provider/model pair.
func makeCandidate(provider, model string) FallbackCandidate {
	var c FallbackCandidate
	c.Provider = provider
	c.Model = model
	return c
}
// successRun returns a run callback that always succeeds with a fresh
// response carrying the given content and a "stop" finish reason.
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
	return func(context.Context, string, string) (*LLMResponse, error) {
		resp := &LLMResponse{Content: content, FinishReason: "stop"}
		return resp, nil
	}
}
// failRun returns a run callback that always fails with err and no response.
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
	return func(context.Context, string, string) (*LLMResponse, error) {
		return nil, err
	}
}
func TestFallback_SingleCandidate_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "hello" {
t.Errorf("content = %q, want hello", result.Response.Content)
}
if result.Provider != "openai" || result.Model != "gpt-4" {
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
}
}
func TestFallback_SecondCandidateSuccess(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude-opus"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
if result.Response.Content != "from claude" {
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
}
if len(result.Attempts) != 1 {
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
}
}
// TestFallback_AllFail checks that when every candidate fails with a
// rate-limit style error, Execute returns a *FallbackExhaustedError that
// records one attempt per candidate.
func TestFallback_AllFail(t *testing.T) {
	ct := NewCooldownTracker()
	fc := NewFallbackChain(ct)

	candidates := []FallbackCandidate{
		makeCandidate("openai", "gpt-4"),
		makeCandidate("anthropic", "claude"),
		makeCandidate("groq", "llama"),
	}

	run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
		return nil, errors.New("rate limit exceeded")
	}

	_, err := fc.Execute(context.Background(), candidates, run)
	if err == nil {
		t.Fatal("expected error when all candidates fail")
	}

	var exhausted *FallbackExhaustedError
	if !errors.As(err, &exhausted) {
		// Must stop here: exhausted is still nil when errors.As fails, so
		// continuing to the Attempts check below would panic.
		t.Fatalf("expected FallbackExhaustedError, got %T: %v", err, err)
	}
	if len(exhausted.Attempts) != 3 {
		t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
	}
}
func TestFallback_ContextCanceled(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
ctx, cancel := context.WithCancel(context.Background())
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
cancel() // cancel context
return nil, context.Canceled
}
t.Error("should not reach second candidate after cancel")
return nil, nil
}
_, err := fc.Execute(ctx, candidates, run)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestFallback_NonRetriableError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("string should match pattern")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for non-retriable")
}
var fe *FailoverError
if !errors.As(err, &fe) {
t.Fatalf("expected FailoverError, got %T", err)
}
if fe.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", fe.Reason)
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
}
}
// TestFallback_CooldownSkip checks that a provider with a recorded failure is
// skipped (never invoked) and that the skip shows up as exactly one skipped
// attempt in the result. newTestTracker is defined outside this view.
func TestFallback_CooldownSkip(t *testing.T) {
	tracker, _ := newTestTracker(time.Now())
	chain := NewFallbackChain(tracker)

	// Put openai in cooldown
	tracker.MarkFailure("openai", FailoverRateLimit)

	run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
		if provider == "openai" {
			t.Error("should not call openai (in cooldown)")
		}
		return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
	}

	res, err := chain.Execute(
		context.Background(),
		[]FallbackCandidate{
			makeCandidate("openai", "gpt-4"),
			makeCandidate("anthropic", "claude"),
		},
		run,
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res.Provider != "anthropic" {
		t.Errorf("provider = %q, want anthropic", res.Provider)
	}

	// Should have 1 skipped attempt
	skippedCount := 0
	for i := range res.Attempts {
		if res.Attempts[i].Skipped {
			skippedCount++
		}
	}
	if skippedCount != 1 {
		t.Errorf("skipped = %d, want 1", skippedCount)
	}
}
// TestFallback_AllInCooldown checks that when every candidate's provider has
// a recorded failure, no provider is invoked and Execute returns a
// *FallbackExhaustedError.
func TestFallback_AllInCooldown(t *testing.T) {
	tracker := NewCooldownTracker()
	chain := NewFallbackChain(tracker)

	// Put all providers in cooldown
	tracker.MarkFailure("openai", FailoverRateLimit)
	tracker.MarkFailure("anthropic", FailoverBilling)

	mustNotRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
		t.Error("should not call any provider (all in cooldown)")
		return nil, nil
	}

	_, err := chain.Execute(
		context.Background(),
		[]FallbackCandidate{
			makeCandidate("openai", "gpt-4"),
			makeCandidate("anthropic", "claude"),
		},
		mustNotRun,
	)
	if err == nil {
		t.Fatal("expected error when all in cooldown")
	}

	var exhausted *FallbackExhaustedError
	if ok := errors.As(err, &exhausted); !ok {
		t.Fatalf("expected FallbackExhaustedError, got %T", err)
	}
}
func TestFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
func TestFallback_EmptyFallbacks(t *testing.T) {
// Single primary, no fallbacks: should work like direct call
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "ok" {
t.Error("expected success with single candidate")
}
}
func TestFallback_UnclassifiedError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("completely unknown internal error")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for unclassified error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
}
}
// TestFallback_SuccessResetsCooldown checks that a successful Execute leaves
// the provider available even though a failure was recorded for it while the
// call was in flight.
func TestFallback_SuccessResetsCooldown(t *testing.T) {
	tracker := NewCooldownTracker()
	chain := NewFallbackChain(tracker)

	calls := 0
	run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
		if calls++; calls == 1 {
			tracker.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
		}
		return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
	}

	candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
	if _, err := chain.Execute(context.Background(), candidates, run); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if available := tracker.IsAvailable("openai"); !available {
		t.Error("success should reset cooldown")
	}
}
// --- Image Fallback Tests ---
func TestImageFallback_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "image result" {
t.Error("expected image result")
}
}
func TestImageFallback_DimensionError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image dimensions exceed max 4096x4096")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image dimension error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
}
}
func TestImageFallback_SizeError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image exceeds 20 mb")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image size error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
}
}
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude-sonnet"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
}
result, err := fc.ExecuteImage(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
}
func TestImageFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
// --- ResolveCandidates Tests ---
func TestResolveCandidates_Simple(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 3 {
t.Fatalf("candidates = %d, want 3", len(candidates))
}
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
}
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
}
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
}
}
func TestResolveCandidates_Deduplication(t *testing.T) {
cfg := ModelConfig{
Primary: "openai/gpt-4",
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "default")
if len(candidates) != 2 {
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
}
}
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: nil,
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
cfg := ModelConfig{
Primary: "",
Fallbacks: []string{"anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestFallbackExhaustedError_Message(t *testing.T) {
e := &FallbackExhaustedError{
Attempts: []FallbackAttempt{
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
{Provider: "anthropic", Model: "claude", Skipped: true},
},
}
msg := e.Error()
if msg == "" {
t.Error("expected non-empty error message")
}
}
+64
View File
@@ -0,0 +1,64 @@
package providers
import "strings"
// ModelRef represents a parsed model reference with provider and model name.
// It is the result type of ParseModelRef.
type ModelRef struct {
// Provider is the provider identifier; ParseModelRef stores it in the
// canonical form produced by NormalizeProvider.
Provider string
// Model is the model name with surrounding whitespace removed; its case is
// kept as written.
Model string
}
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
//
// The input is trimmed first. The text before the first slash is the provider
// (normalized via NormalizeProvider) and everything after it, trimmed, is the
// model. When there is no slash, or the slash is the very first character,
// the whole trimmed input is the model and defaultProvider (normalized) is
// used as the provider.
//
// Returns nil for empty input and for a reference whose model part after the
// slash is empty.
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
	ref := strings.TrimSpace(raw)
	if ref == "" {
		return nil
	}

	slash := strings.Index(ref, "/")
	if slash <= 0 {
		// No provider segment in front of a slash: fall back to the default.
		return &ModelRef{
			Provider: NormalizeProvider(defaultProvider),
			Model:    ref,
		}
	}

	model := strings.TrimSpace(ref[slash+1:])
	if model == "" {
		return nil
	}
	return &ModelRef{
		Provider: NormalizeProvider(ref[:slash]),
		Model:    model,
	}
}
// NormalizeProvider normalizes provider identifiers to canonical form.
//
// The input is trimmed and lower-cased, then known aliases are mapped to
// their canonical provider name. Anything not listed as an alias is returned
// in its trimmed, lower-cased form.
func NormalizeProvider(provider string) string {
	switch name := strings.ToLower(strings.TrimSpace(provider)); name {
	case "z.ai", "z-ai":
		return "zai"
	case "opencode-zen":
		return "opencode"
	case "qwen":
		return "qwen-portal"
	case "kimi-code":
		return "kimi-coding"
	case "gpt":
		return "openai"
	case "claude":
		return "anthropic"
	case "glm":
		return "zhipu"
	case "google":
		return "gemini"
	default:
		return name
	}
}
// ModelKey returns a canonical "provider/model" key for deduplication.
// The provider goes through NormalizeProvider; the model is trimmed and
// lower-cased.
func ModelKey(provider, model string) string {
	canonicalProvider := NormalizeProvider(provider)
	canonicalModel := strings.ToLower(strings.TrimSpace(model))
	return canonicalProvider + "/" + canonicalModel
}
+125
View File
@@ -0,0 +1,125 @@
package providers
import "testing"
// TestParseModelRef_WithSlash checks that an explicit "provider/model"
// reference overrides the default provider.
func TestParseModelRef_WithSlash(t *testing.T) {
	parsed := ParseModelRef("anthropic/claude-opus", "openai")
	if parsed == nil {
		t.Fatal("expected non-nil ref")
	}

	if got := parsed.Provider; got != "anthropic" {
		t.Errorf("provider = %q, want anthropic", got)
	}
	if got := parsed.Model; got != "claude-opus" {
		t.Errorf("model = %q, want claude-opus", got)
	}
}
// TestParseModelRef_WithoutSlash checks that a bare model name is paired with
// the default provider.
func TestParseModelRef_WithoutSlash(t *testing.T) {
	parsed := ParseModelRef("gpt-4", "openai")
	if parsed == nil {
		t.Fatal("expected non-nil ref")
	}

	if got := parsed.Provider; got != "openai" {
		t.Errorf("provider = %q, want openai", got)
	}
	if got := parsed.Model; got != "gpt-4" {
		t.Errorf("model = %q, want gpt-4", got)
	}
}
// TestParseModelRef_Empty checks that an empty reference parses to nil.
func TestParseModelRef_Empty(t *testing.T) {
	if parsed := ParseModelRef("", "openai"); parsed != nil {
		t.Errorf("expected nil for empty string, got %+v", parsed)
	}
}
// TestParseModelRef_EmptyModelAfterSlash checks that "provider/" with nothing
// after the slash parses to nil.
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
	if parsed := ParseModelRef("openai/", "default"); parsed != nil {
		t.Errorf("expected nil for empty model, got %+v", parsed)
	}
}
// TestParseModelRef_WhitespaceHandling checks that whitespace around the
// whole reference and around the slash is stripped from both parts.
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
	parsed := ParseModelRef(" anthropic / claude-opus ", "openai")
	if parsed == nil {
		t.Fatal("expected non-nil ref")
	}

	if got := parsed.Provider; got != "anthropic" {
		t.Errorf("provider = %q, want anthropic", got)
	}
	if got := parsed.Model; got != "claude-opus" {
		t.Errorf("model = %q, want claude-opus", got)
	}
}
// TestNormalizeProvider is a table test covering case folding, every alias
// mapping, a pass-through name and the empty string.
func TestNormalizeProvider(t *testing.T) {
	type normCase struct {
		in   string
		want string
	}
	cases := []normCase{
		{"OpenAI", "openai"},
		{"ANTHROPIC", "anthropic"},
		{"z.ai", "zai"},
		{"z-ai", "zai"},
		{"Z.AI", "zai"},
		{"opencode-zen", "opencode"},
		{"qwen", "qwen-portal"},
		{"kimi-code", "kimi-coding"},
		{"gpt", "openai"},
		{"claude", "anthropic"},
		{"glm", "zhipu"},
		{"google", "gemini"},
		{"groq", "groq"},
		{"", ""},
	}

	for i := range cases {
		c := cases[i]
		if got := NormalizeProvider(c.in); got != c.want {
			t.Errorf("NormalizeProvider(%q) = %q, want %q", c.in, got, c.want)
		}
	}
}
// TestModelKey is a table test checking that keys combine the normalized
// provider with the lower-cased model.
func TestModelKey(t *testing.T) {
	type keyCase struct {
		provider string
		model    string
		want     string
	}
	cases := []keyCase{
		{"openai", "gpt-4", "openai/gpt-4"},
		{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
		{"claude", "sonnet", "anthropic/sonnet"},
		{"z.ai", "Model-X", "zai/model-x"},
	}

	for i := range cases {
		c := cases[i]
		if got := ModelKey(c.provider, c.model); got != c.want {
			t.Errorf("ModelKey(%q, %q) = %q, want %q", c.provider, c.model, got, c.want)
		}
	}
}
// TestParseModelRef_ProviderNormalization checks that an explicit provider
// alias in the reference is normalized.
func TestParseModelRef_ProviderNormalization(t *testing.T) {
	parsed := ParseModelRef("Z.AI/model-x", "default")
	if parsed == nil {
		t.Fatal("expected non-nil ref")
	}
	if got := parsed.Provider; got != "zai" {
		t.Errorf("provider = %q, want zai", got)
	}
}
// TestParseModelRef_DefaultProviderNormalization checks that the default
// provider is normalized too when the reference has no provider part.
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
	parsed := ParseModelRef("gpt-4o", "GPT")
	if parsed == nil {
		t.Fatal("expected non-nil ref")
	}
	if got := parsed.Provider; got != "openai" {
		t.Errorf("provider = %q, want openai (normalized from GPT)", got)
	}
}
+47 -1
View File
@@ -1,6 +1,9 @@
package providers
import "context"
import (
"context"
"fmt"
)
type ToolCall struct {
ID string `json:"id"`
@@ -40,6 +43,49 @@ type LLMProvider interface {
GetDefaultModel() string
}
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string

// Failure classifications carried by FailoverError. Of these, only
// FailoverFormat is treated as non-retriable by IsRetriable.
const (
	FailoverAuth       FailoverReason = "auth"
	FailoverRateLimit  FailoverReason = "rate_limit"
	FailoverBilling    FailoverReason = "billing"
	FailoverTimeout    FailoverReason = "timeout"
	FailoverFormat     FailoverReason = "format"
	FailoverOverloaded FailoverReason = "overloaded"
	FailoverUnknown    FailoverReason = "unknown"
)

// FailoverError wraps an LLM provider error with classification metadata.
type FailoverError struct {
	Reason   FailoverReason // classification of the failure
	Provider string         // provider that produced the error
	Model    string         // model that was requested
	Status   int            // status code printed by Error; zero when left unset
	Wrapped  error          // underlying error, exposed through Unwrap
}

// Error formats the reason, provider, model, status and wrapped error into a
// single line.
func (e *FailoverError) Error() string {
	const layout = "failover(%s): provider=%s model=%s status=%d: %v"
	return fmt.Sprintf(layout, e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
}

// Unwrap exposes the underlying error so errors.Is / errors.As can see it.
func (e *FailoverError) Unwrap() error {
	return e.Wrapped
}

// IsRetriable returns true if this error should trigger fallback to next candidate.
// Non-retriable: Format errors (bad request structure, image dimension/size).
func (e *FailoverError) IsRetriable() bool {
	if e.Reason == FailoverFormat {
		return false
	}
	return true
}
// ModelConfig holds primary model and fallback list.
// The tests in this package pass it to ResolveCandidates together with a
// default provider name.
type ModelConfig struct {
// Primary is the preferred model reference, e.g. "gpt-4" or "openai/gpt-4".
Primary string
// Fallbacks lists alternative model references, e.g. "anthropic/claude-opus".
// NOTE(review): presumably tried in slice order after Primary — confirm
// against ResolveCandidates.
Fallbacks []string
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`