mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2547 from lc6464/chore/issue-2538-network-fallback
feat(network): improve network error classification and fallback handling
This commit is contained in:
@@ -2,8 +2,12 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Common patterns in Go HTTP error messages
|
||||
@@ -50,6 +54,30 @@ var (
|
||||
substr("context deadline exceeded"),
|
||||
}
|
||||
|
||||
networkPatterns = []errorPattern{
|
||||
substr("connection reset"),
|
||||
substr("reset by peer"),
|
||||
substr("connection refused"),
|
||||
substr("connection aborted"),
|
||||
substr("broken pipe"),
|
||||
substr("use of closed network connection"),
|
||||
substr("network is unreachable"),
|
||||
substr("host is unreachable"),
|
||||
substr("no such host"),
|
||||
substr("temporary failure in name resolution"),
|
||||
substr("server misbehaving"),
|
||||
substr("read tcp"),
|
||||
substr("write tcp"),
|
||||
substr("dial tcp"),
|
||||
substr("tls:"),
|
||||
substr("x509:"),
|
||||
substr("certificate"),
|
||||
substr("handshake"),
|
||||
substr("unexpected eof"),
|
||||
substr("read: eof"),
|
||||
substr("write: eof"),
|
||||
}
|
||||
|
||||
billingPatterns = []errorPattern{
|
||||
rxp(`\b402\b`),
|
||||
substr("payment required"),
|
||||
@@ -134,6 +162,17 @@ func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// Concrete transport errors should continue the fallback chain even when
|
||||
// providers do not expose a structured HTTP status.
|
||||
if reason := classifyByErrorType(err); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Image dimension/size errors: non-retriable, non-fallback.
|
||||
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
|
||||
return &FailoverError{
|
||||
@@ -170,6 +209,41 @@ func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyByErrorType maps concrete transport-layer error types to a retryable
|
||||
// fallback reason before message heuristics are applied.
|
||||
func classifyByErrorType(err error) FailoverReason {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
|
||||
for _, transportErr := range []error{
|
||||
syscall.ECONNRESET,
|
||||
syscall.ECONNABORTED,
|
||||
syscall.ECONNREFUSED,
|
||||
syscall.ETIMEDOUT,
|
||||
syscall.EHOSTUNREACH,
|
||||
syscall.ENETUNREACH,
|
||||
syscall.EPIPE,
|
||||
} {
|
||||
if errors.Is(err, transportErr) {
|
||||
if transportErr == syscall.ETIMEDOUT {
|
||||
return FailoverTimeout
|
||||
}
|
||||
return FailoverNetwork
|
||||
}
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
return FailoverTimeout
|
||||
}
|
||||
return FailoverNetwork
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyByStatus maps HTTP status codes to FailoverReason.
|
||||
func classifyByStatus(status int) FailoverReason {
|
||||
switch {
|
||||
@@ -204,6 +278,9 @@ func classifyByMessage(msg string) FailoverReason {
|
||||
if matchesAny(msg, timeoutPatterns) {
|
||||
return FailoverTimeout
|
||||
}
|
||||
if matchesAny(msg, networkPatterns) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
if matchesAny(msg, authPatterns) {
|
||||
return FailoverAuth
|
||||
}
|
||||
|
||||
@@ -4,9 +4,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stubNetError struct {
|
||||
msg string
|
||||
timeout bool
|
||||
}
|
||||
|
||||
func (e stubNetError) Error() string { return e.msg }
|
||||
func (e stubNetError) Timeout() bool { return e.timeout }
|
||||
func (e stubNetError) Temporary() bool { return false }
|
||||
|
||||
func TestClassifyError_Nil(t *testing.T) {
|
||||
result := ClassifyError(nil, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
@@ -154,6 +167,129 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
`failed to send request: Post "https://example.com": tls: bad record MAC`,
|
||||
"read tcp 10.20.0.1:61279->172.65.90.20:443: read: connection reset by peer",
|
||||
"failed to send request: dial tcp 203.0.113.10:443: connect: connection refused",
|
||||
"tls handshake failure",
|
||||
"x509: certificate has expired or is not yet valid",
|
||||
"read tcp 127.0.0.1:443: read: unexpected EOF",
|
||||
"lookup api.example.com: no such host",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Errorf("pattern %q: reason = %q, want network", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "wrapped EOF",
|
||||
err: &url.Error{
|
||||
Op: "Post",
|
||||
URL: "https://example.com",
|
||||
Err: io.EOF,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dns error",
|
||||
err: &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: "api.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClassifyError(tt.err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Fatalf("reason = %q, want network", result.Reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_TimeoutNetworkTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "wrapped syscall timeout",
|
||||
err: fmt.Errorf("dial tcp: %w", syscall.ETIMEDOUT),
|
||||
},
|
||||
{
|
||||
name: "net error timeout",
|
||||
err: &url.Error{
|
||||
Op: "Post",
|
||||
URL: "https://example.com",
|
||||
Err: stubNetError{msg: "i/o timeout", timeout: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClassifyError(tt.err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Fatalf("reason = %q, want timeout", result.Reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_TimeoutPatternsWinOverNetworkContext(t *testing.T) {
|
||||
patterns := []string{
|
||||
`failed to send request: Post "https://example.com": dial tcp 203.0.113.10:443: i/o timeout`,
|
||||
`read tcp 10.20.0.1:61279->172.65.90.20:443: i/o timeout`,
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkPatternsWinOverAuthExpired(t *testing.T) {
|
||||
err := errors.New(
|
||||
`Post "https://example.com": tls: failed to verify certificate: x509: certificate has expired or is not yet valid`,
|
||||
)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Fatalf("reason = %q, want network", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_AuthPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"invalid api key",
|
||||
@@ -286,6 +422,7 @@ func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
{FailoverAuth, true},
|
||||
{FailoverRateLimit, true},
|
||||
{FailoverBilling, true},
|
||||
{FailoverNetwork, true},
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
|
||||
@@ -268,6 +268,75 @@ func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertFallbackErrorFallsBack(
|
||||
t *testing.T,
|
||||
primaryProvider string,
|
||||
primaryModel string,
|
||||
initialErr error,
|
||||
successContent string,
|
||||
expectedReason FailoverReason,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate(primaryProvider, primaryModel),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, initialErr
|
||||
}
|
||||
return &LLMResponse{Content: successContent, FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback success, got error: %v", err)
|
||||
}
|
||||
if attempt != 2 {
|
||||
t.Fatalf("attempt = %d, want 2", attempt)
|
||||
}
|
||||
if result.Provider != "anthropic" || result.Model != "claude" {
|
||||
t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model)
|
||||
}
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts))
|
||||
}
|
||||
if result.Attempts[0].Reason != expectedReason {
|
||||
t.Fatalf("attempt reason = %q, want %s", result.Attempts[0].Reason, expectedReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NetworkErrorFallsBack(t *testing.T) {
|
||||
assertFallbackErrorFallsBack(
|
||||
t,
|
||||
"minimax",
|
||||
"minimax-m2.7",
|
||||
errors.New(
|
||||
`failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`,
|
||||
),
|
||||
"fallback ok",
|
||||
FailoverNetwork,
|
||||
)
|
||||
}
|
||||
|
||||
func TestFallback_TimeoutErrorFallsBack(t *testing.T) {
|
||||
assertFallbackErrorFallsBack(
|
||||
t,
|
||||
"openai",
|
||||
"gpt-4",
|
||||
errors.New("failed to send request: Post \"https://example.com\": i/o timeout"),
|
||||
"timeout fallback ok",
|
||||
FailoverTimeout,
|
||||
)
|
||||
}
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
@@ -74,6 +74,7 @@ const (
|
||||
FailoverAuth FailoverReason = "auth"
|
||||
FailoverRateLimit FailoverReason = "rate_limit"
|
||||
FailoverBilling FailoverReason = "billing"
|
||||
FailoverNetwork FailoverReason = "network"
|
||||
FailoverTimeout FailoverReason = "timeout"
|
||||
FailoverFormat FailoverReason = "format"
|
||||
FailoverContextOverflow FailoverReason = "context_overflow"
|
||||
|
||||
Reference in New Issue
Block a user