mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(network): implement network error classification and fallback handling
This commit is contained in:
@@ -2,8 +2,12 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Common patterns in Go HTTP error messages
|
||||
@@ -50,6 +54,30 @@ var (
|
||||
substr("context deadline exceeded"),
|
||||
}
|
||||
|
||||
networkPatterns = []errorPattern{
|
||||
substr("connection reset"),
|
||||
substr("reset by peer"),
|
||||
substr("connection refused"),
|
||||
substr("connection aborted"),
|
||||
substr("broken pipe"),
|
||||
substr("use of closed network connection"),
|
||||
substr("network is unreachable"),
|
||||
substr("host is unreachable"),
|
||||
substr("no such host"),
|
||||
substr("temporary failure in name resolution"),
|
||||
substr("server misbehaving"),
|
||||
substr("read tcp"),
|
||||
substr("write tcp"),
|
||||
substr("dial tcp"),
|
||||
substr("tls:"),
|
||||
substr("x509:"),
|
||||
substr("certificate"),
|
||||
substr("handshake"),
|
||||
substr("unexpected eof"),
|
||||
substr("read: eof"),
|
||||
substr("write: eof"),
|
||||
}
|
||||
|
||||
billingPatterns = []errorPattern{
|
||||
rxp(`\b402\b`),
|
||||
substr("payment required"),
|
||||
@@ -134,6 +162,17 @@ func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// Concrete transport errors should continue the fallback chain even when
|
||||
// providers do not expose a structured HTTP status.
|
||||
if reason := classifyByErrorType(err); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Image dimension/size errors: non-retriable, non-fallback.
|
||||
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
|
||||
return &FailoverError{
|
||||
@@ -170,6 +209,35 @@ func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyByErrorType maps concrete transport-layer error types to a retryable
|
||||
// fallback reason before message heuristics are applied.
|
||||
func classifyByErrorType(err error) FailoverReason {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
|
||||
for _, transportErr := range []error{
|
||||
syscall.ECONNRESET,
|
||||
syscall.ECONNABORTED,
|
||||
syscall.ECONNREFUSED,
|
||||
syscall.ETIMEDOUT,
|
||||
syscall.EHOSTUNREACH,
|
||||
syscall.ENETUNREACH,
|
||||
syscall.EPIPE,
|
||||
} {
|
||||
if errors.Is(err, transportErr) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyByStatus maps HTTP status codes to FailoverReason.
|
||||
func classifyByStatus(status int) FailoverReason {
|
||||
switch {
|
||||
@@ -204,6 +272,9 @@ func classifyByMessage(msg string) FailoverReason {
|
||||
if matchesAny(msg, timeoutPatterns) {
|
||||
return FailoverTimeout
|
||||
}
|
||||
if matchesAny(msg, networkPatterns) {
|
||||
return FailoverNetwork
|
||||
}
|
||||
if matchesAny(msg, authPatterns) {
|
||||
return FailoverAuth
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -154,6 +157,78 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
`failed to send request: Post "https://example.com": tls: bad record MAC`,
|
||||
"read tcp 10.20.0.1:61279->172.65.90.20:443: read: connection reset by peer",
|
||||
"failed to send request: dial tcp 203.0.113.10:443: connect: connection refused",
|
||||
"tls handshake failure",
|
||||
"x509: certificate has expired or is not yet valid",
|
||||
"read tcp 127.0.0.1:443: read: unexpected EOF",
|
||||
"lookup api.example.com: no such host",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Errorf("pattern %q: reason = %q, want network", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "wrapped EOF",
|
||||
err: &url.Error{
|
||||
Op: "Post",
|
||||
URL: "https://example.com",
|
||||
Err: io.EOF,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dns error",
|
||||
err: &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: "api.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClassifyError(tt.err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Fatalf("reason = %q, want network", result.Reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_NetworkPatternsWinOverAuthExpired(t *testing.T) {
|
||||
err := errors.New(
|
||||
`Post "https://example.com": tls: failed to verify certificate: x509: certificate has expired or is not yet valid`,
|
||||
)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Reason != FailoverNetwork {
|
||||
t.Fatalf("reason = %q, want network", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_AuthPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"invalid api key",
|
||||
@@ -286,6 +361,7 @@ func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
{FailoverAuth, true},
|
||||
{FailoverRateLimit, true},
|
||||
{FailoverBilling, true},
|
||||
{FailoverNetwork, true},
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
|
||||
@@ -268,6 +268,44 @@ func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NetworkErrorFallsBack(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("minimax", "minimax-m2.7"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New(
|
||||
`failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`,
|
||||
)
|
||||
}
|
||||
return &LLMResponse{Content: "fallback ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback success, got error: %v", err)
|
||||
}
|
||||
if attempt != 2 {
|
||||
t.Fatalf("attempt = %d, want 2", attempt)
|
||||
}
|
||||
if result.Provider != "anthropic" || result.Model != "claude" {
|
||||
t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model)
|
||||
}
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts))
|
||||
}
|
||||
if result.Attempts[0].Reason != FailoverNetwork {
|
||||
t.Fatalf("attempt reason = %q, want network", result.Attempts[0].Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
@@ -74,6 +74,7 @@ const (
|
||||
FailoverAuth FailoverReason = "auth"
|
||||
FailoverRateLimit FailoverReason = "rate_limit"
|
||||
FailoverBilling FailoverReason = "billing"
|
||||
FailoverNetwork FailoverReason = "network"
|
||||
FailoverTimeout FailoverReason = "timeout"
|
||||
FailoverFormat FailoverReason = "format"
|
||||
FailoverContextOverflow FailoverReason = "context_overflow"
|
||||
|
||||
Reference in New Issue
Block a user