diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go index 8ce98f8ba..88c92a47d 100644 --- a/pkg/providers/error_classifier.go +++ b/pkg/providers/error_classifier.go @@ -226,12 +226,18 @@ func classifyByErrorType(err error) FailoverReason { syscall.EPIPE, } { if errors.Is(err, transportErr) { + if transportErr == syscall.ETIMEDOUT { + return FailoverTimeout + } return FailoverNetwork } } var netErr net.Error if errors.As(err, &netErr) { + if netErr.Timeout() { + return FailoverTimeout + } return FailoverNetwork } diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go index 01b118078..571fb3882 100644 --- a/pkg/providers/error_classifier_test.go +++ b/pkg/providers/error_classifier_test.go @@ -7,9 +7,19 @@ import ( "io" "net" "net/url" + "syscall" "testing" ) +type stubNetError struct { + msg string + timeout bool +} + +func (e stubNetError) Error() string { return e.msg } +func (e stubNetError) Timeout() bool { return e.timeout } +func (e stubNetError) Temporary() bool { return false } + func TestClassifyError_Nil(t *testing.T) { result := ClassifyError(nil, "openai", "gpt-4") if result != nil { @@ -216,6 +226,57 @@ func TestClassifyError_NetworkTypes(t *testing.T) { } } +func TestClassifyError_TimeoutNetworkTypes(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "wrapped syscall timeout", + err: fmt.Errorf("dial tcp: %w", syscall.ETIMEDOUT), + }, + { + name: "net error timeout", + err: &url.Error{ + Op: "Post", + URL: "https://example.com", + Err: stubNetError{msg: "i/o timeout", timeout: true}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifyError(tt.err, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Reason != FailoverTimeout { + t.Fatalf("reason = %q, want timeout", result.Reason) + } + }) + } +} + +func TestClassifyError_TimeoutPatternsWinOverNetworkContext(t *testing.T) { + patterns := []string{ + `failed to send request: Post "https://example.com": dial tcp 203.0.113.10:443: i/o timeout`, + `read tcp 10.20.0.1:61279->172.65.90.20:443: i/o timeout`, + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + func TestClassifyError_NetworkPatternsWinOverAuthExpired(t *testing.T) { err := errors.New( `Post "https://example.com": tls: failed to verify certificate: x509: certificate has expired or is not yet valid`, diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 2bc50ec50..7b0652683 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -306,6 +306,42 @@ func TestFallback_NetworkErrorFallsBack(t *testing.T) { } } +func TestFallback_TimeoutErrorFallsBack(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct, nil) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("failed to send request: Post \"https://example.com\": i/o timeout") + } + return &LLMResponse{Content: "timeout fallback ok", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("expected fallback success, got error: %v", err) + } + if attempt != 2 { + t.Fatalf("attempt = %d, want 2", attempt) + } + if result.Provider != "anthropic" || result.Model != "claude" { + t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model) + } + if len(result.Attempts) != 1 { + t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts)) + } + if result.Attempts[0].Reason != FailoverTimeout { + t.Fatalf("attempt reason = %q, want timeout", result.Attempts[0].Reason) + } +} + func TestFallback_SuccessResetsCooldown(t *testing.T) { ct := NewCooldownTracker() fc := NewFallbackChain(ct, nil)