diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go index e7691aa93..88c92a47d 100644 --- a/pkg/providers/error_classifier.go +++ b/pkg/providers/error_classifier.go @@ -2,8 +2,12 @@ package providers import ( "context" + "errors" + "io" + "net" "regexp" "strings" + "syscall" ) // Common patterns in Go HTTP error messages @@ -50,6 +54,30 @@ var ( substr("context deadline exceeded"), } + networkPatterns = []errorPattern{ + substr("connection reset"), + substr("reset by peer"), + substr("connection refused"), + substr("connection aborted"), + substr("broken pipe"), + substr("use of closed network connection"), + substr("network is unreachable"), + substr("host is unreachable"), + substr("no such host"), + substr("temporary failure in name resolution"), + substr("server misbehaving"), + substr("read tcp"), + substr("write tcp"), + substr("dial tcp"), + substr("tls:"), + substr("x509:"), + substr("certificate"), + substr("handshake"), + substr("unexpected eof"), + substr("read: eof"), + substr("write: eof"), + } + billingPatterns = []errorPattern{ rxp(`\b402\b`), substr("payment required"), @@ -134,6 +162,17 @@ func ClassifyError(err error, provider, model string) *FailoverError { msg := strings.ToLower(err.Error()) + // Concrete transport errors should continue the fallback chain even when + // providers do not expose a structured HTTP status. + if reason := classifyByErrorType(err); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Wrapped: err, + } + } + // Image dimension/size errors: non-retriable, non-fallback. if IsImageDimensionError(msg) || IsImageSizeError(msg) { return &FailoverError{ @@ -170,6 +209,41 @@ func ClassifyError(err error, provider, model string) *FailoverError { return nil } +// classifyByErrorType maps concrete transport-layer error types to a retryable +// fallback reason before message heuristics are applied. +func classifyByErrorType(err error) FailoverReason { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return FailoverNetwork + } + + for _, transportErr := range []error{ + syscall.ECONNRESET, + syscall.ECONNABORTED, + syscall.ECONNREFUSED, + syscall.ETIMEDOUT, + syscall.EHOSTUNREACH, + syscall.ENETUNREACH, + syscall.EPIPE, + } { + if errors.Is(err, transportErr) { + if transportErr == syscall.ETIMEDOUT { + return FailoverTimeout + } + return FailoverNetwork + } + } + + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return FailoverTimeout + } + return FailoverNetwork + } + + return "" +} + // classifyByStatus maps HTTP status codes to FailoverReason. func classifyByStatus(status int) FailoverReason { switch { @@ -204,6 +278,9 @@ func classifyByMessage(msg string) FailoverReason { if matchesAny(msg, timeoutPatterns) { return FailoverTimeout } + if matchesAny(msg, networkPatterns) { + return FailoverNetwork + } if matchesAny(msg, authPatterns) { return FailoverAuth } diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go index 46b180835..571fb3882 100644 --- a/pkg/providers/error_classifier_test.go +++ b/pkg/providers/error_classifier_test.go @@ -4,9 +4,22 @@ import ( "context" "errors" "fmt" + "io" + "net" + "net/url" + "syscall" "testing" ) +type stubNetError struct { + msg string + timeout bool +} + +func (e stubNetError) Error() string { return e.msg } +func (e stubNetError) Timeout() bool { return e.timeout } +func (e stubNetError) Temporary() bool { return false } + func TestClassifyError_Nil(t *testing.T) { result := ClassifyError(nil, "openai", "gpt-4") if result != nil { @@ -154,6 +167,129 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) { } } +func TestClassifyError_NetworkPatterns(t *testing.T) { + patterns := []string{ + `failed to send request: Post "https://example.com": tls: bad record MAC`, + "read tcp 10.20.0.1:61279->172.65.90.20:443: read: connection reset by peer", + "failed to send request: dial tcp 203.0.113.10:443: connect: connection refused", + "tls handshake failure", + "x509: certificate has expired or is not yet valid", + "read tcp 127.0.0.1:443: read: unexpected EOF", + "lookup api.example.com: no such host", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverNetwork { + t.Errorf("pattern %q: reason = %q, want network", msg, result.Reason) + } + } +} + +func TestClassifyError_NetworkTypes(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "wrapped EOF", + err: &url.Error{ + Op: "Post", + URL: "https://example.com", + Err: io.EOF, + }, + }, + { + name: "dns error", + err: &net.DNSError{ + Err: "no such host", + Name: "api.example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifyError(tt.err, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Reason != FailoverNetwork { + t.Fatalf("reason = %q, want network", result.Reason) + } + }) + } +} + +func TestClassifyError_TimeoutNetworkTypes(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "wrapped syscall timeout", + err: fmt.Errorf("dial tcp: %w", syscall.ETIMEDOUT), + }, + { + name: "net error timeout", + err: &url.Error{ + Op: "Post", + URL: "https://example.com", + Err: stubNetError{msg: "i/o timeout", timeout: true}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifyError(tt.err, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Reason != FailoverTimeout { + t.Fatalf("reason = %q, want timeout", result.Reason) + } + }) + } +} + +func TestClassifyError_TimeoutPatternsWinOverNetworkContext(t *testing.T) { + patterns := []string{ + `failed to send request: Post "https://example.com": dial tcp 203.0.113.10:443: i/o timeout`, + `read tcp 10.20.0.1:61279->172.65.90.20:443: i/o timeout`, + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + +func TestClassifyError_NetworkPatternsWinOverAuthExpired(t *testing.T) { + err := errors.New( + `Post "https://example.com": tls: failed to verify certificate: x509: certificate has expired or is not yet valid`, + ) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Reason != FailoverNetwork { + t.Fatalf("reason = %q, want network", result.Reason) + } +} + func TestClassifyError_AuthPatterns(t *testing.T) { patterns := []string{ "invalid api key", @@ -286,6 +422,7 @@ func TestFailoverError_IsRetriable(t *testing.T) { {FailoverAuth, true}, {FailoverRateLimit, true}, {FailoverBilling, true}, + {FailoverNetwork, true}, {FailoverTimeout, true}, {FailoverOverloaded, true}, {FailoverFormat, false}, diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 54fb9b6ea..07cc01baa 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -268,6 +268,75 @@ func TestFallback_UnclassifiedError(t *testing.T) { } } +func assertFallbackErrorFallsBack( + t *testing.T, + primaryProvider string, + primaryModel string, + initialErr error, + successContent string, + expectedReason FailoverReason, +) { + t.Helper() + + ct := NewCooldownTracker() + fc := NewFallbackChain(ct, nil) + + candidates := []FallbackCandidate{ + makeCandidate(primaryProvider, primaryModel), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, initialErr + } + return &LLMResponse{Content: successContent, FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("expected fallback success, got error: %v", err) + } + if attempt != 2 { + t.Fatalf("attempt = %d, want 2", attempt) + } + if result.Provider != "anthropic" || result.Model != "claude" { + t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model) + } + if len(result.Attempts) != 1 { + t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts)) + } + if result.Attempts[0].Reason != expectedReason { + t.Fatalf("attempt reason = %q, want %s", result.Attempts[0].Reason, expectedReason) + } +} + +func TestFallback_NetworkErrorFallsBack(t *testing.T) { + assertFallbackErrorFallsBack( + t, + "minimax", + "minimax-m2.7", + errors.New( + `failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`, + ), + "fallback ok", + FailoverNetwork, + ) +} + +func TestFallback_TimeoutErrorFallsBack(t *testing.T) { + assertFallbackErrorFallsBack( + t, + "openai", + "gpt-4", + errors.New("failed to send request: Post \"https://example.com\": i/o timeout"), + "timeout fallback ok", + FailoverTimeout, + ) +} + func TestFallback_SuccessResetsCooldown(t *testing.T) { ct := NewCooldownTracker() fc := NewFallbackChain(ct, nil) diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f98ae9243..fae252d13 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -74,6 +74,7 @@ const ( FailoverAuth FailoverReason = "auth" FailoverRateLimit FailoverReason = "rate_limit" FailoverBilling FailoverReason = "billing" + FailoverNetwork FailoverReason = "network" FailoverTimeout FailoverReason = "timeout" FailoverFormat FailoverReason = "format" FailoverContextOverflow FailoverReason = "context_overflow"