Merge pull request #2016 from badgerbees/fix/context-overflow-errors

fix(providers): improve context overflow detection and classification
This commit is contained in:
Mauro
2026-03-25 21:58:53 +01:00
committed by GitHub
5 changed files with 156 additions and 8 deletions
+1
View File
@@ -1949,6 +1949,7 @@ turnLoop:
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "context_window") ||
strings.Contains(errMsg, "maximum context length") ||
strings.Contains(errMsg, "token limit") ||
strings.Contains(errMsg, "too many tokens") ||
+109
View File
@@ -3,6 +3,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@@ -2817,3 +2818,111 @@ func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
t.Fatalf("len(result) = %d, want 0", len(result))
}
}
type overflowProvider struct {
calls int
lastMessages []providers.Message
chatFunc func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any) (*providers.LLMResponse, error)
}
func (p *overflowProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.calls++
p.lastMessages = append([]providers.Message(nil), messages...)
if p.chatFunc != nil {
return p.chatFunc(ctx, messages, tools, model, opts)
}
if p.calls == 1 {
return nil, errors.New("context_window_exceeded")
}
return &providers.LLMResponse{
Content: "Recovered from overflow",
}, nil
}
func (p *overflowProvider) GetDefaultModel() string {
return "test-model"
}
func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
_ = cfg
provider := &overflowProvider{}
al.registry = NewAgentRegistry(al.cfg, provider)
sessionKey := "agent:main:test-session"
agent := al.GetRegistry().GetDefaultAgent()
for i := 0; i < 5; i++ {
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
SessionKey: "test-session",
Content: "trigger recovery",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Recovered from overflow" {
t.Fatalf("response = %q, want %q", response, "Recovered from overflow")
}
if provider.calls != 2 {
t.Fatalf("expected 2 calls, got %d", provider.calls)
}
}
func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
_ = cfg
provider := &overflowProvider{}
al.registry = NewAgentRegistry(al.cfg, provider)
recoveryMsg := "error: status 400: context_window_exceeded"
provider.chatFunc = func(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
if provider.calls == 1 {
return nil, errors.New(recoveryMsg)
}
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
Content: "hello",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if !strings.Contains(response, "Anthropic recovery success") {
t.Fatalf("response = %q, want success message", response)
}
if provider.calls != 2 {
t.Fatalf("expected 2 calls for retry, got %d", provider.calls)
}
}
+12
View File
@@ -84,6 +84,15 @@ var (
substr("messages.1.content.1.tool_use.id"),
substr("invalid request format"),
}
contextOverflowPatterns = []errorPattern{
rxp(`context[_ ]?length[_ ]?exceeded`),
rxp(`context[_ ]?window[_ ]?exceeded`),
substr("maximum context length"),
substr("token limit"),
substr("too many tokens"),
substr("prompt is too long"),
substr("request too large"),
}
imageDimensionPatterns = []errorPattern{
rxp(`image dimensions exceed max`),
@@ -201,6 +210,9 @@ func classifyByMessage(msg string) FailoverReason {
if matchesAny(msg, formatPatterns) {
return FailoverFormat
}
if matchesAny(msg, contextOverflowPatterns) {
return FailoverContextOverflow
}
return ""
}
+25
View File
@@ -221,6 +221,30 @@ func TestClassifyError_ImageDimensionError(t *testing.T) {
}
}
func TestClassifyError_ContextOverflowPatterns(t *testing.T) {
patterns := []string{
"context_length_exceeded",
"context_window_exceeded",
"maximum context length",
"token limit",
"too many tokens",
"prompt is too long",
"request too large",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverContextOverflow {
t.Errorf("pattern %q: reason = %q, want context_overflow", msg, result.Reason)
}
}
}
func TestClassifyError_ImageSizeError(t *testing.T) {
err := errors.New("image exceeds 20 mb limit")
result := ClassifyError(err, "openai", "gpt-4o")
@@ -265,6 +289,7 @@ func TestFailoverError_IsRetriable(t *testing.T) {
{FailoverTimeout, true},
{FailoverOverloaded, true},
{FailoverFormat, false},
{FailoverContextOverflow, false},
{FailoverUnknown, true},
}
+9 -8
View File
@@ -71,13 +71,14 @@ type NativeSearchCapable interface {
type FailoverReason string
const (
FailoverAuth FailoverReason = "auth"
FailoverRateLimit FailoverReason = "rate_limit"
FailoverBilling FailoverReason = "billing"
FailoverTimeout FailoverReason = "timeout"
FailoverFormat FailoverReason = "format"
FailoverOverloaded FailoverReason = "overloaded"
FailoverUnknown FailoverReason = "unknown"
FailoverAuth FailoverReason = "auth"
FailoverRateLimit FailoverReason = "rate_limit"
FailoverBilling FailoverReason = "billing"
FailoverTimeout FailoverReason = "timeout"
FailoverFormat FailoverReason = "format"
FailoverContextOverflow FailoverReason = "context_overflow"
FailoverOverloaded FailoverReason = "overloaded"
FailoverUnknown FailoverReason = "unknown"
)
// FailoverError wraps an LLM provider error with classification metadata.
@@ -101,7 +102,7 @@ func (e *FailoverError) Unwrap() error {
// IsRetriable returns true if this error should trigger fallback to next candidate.
// Non-retriable: Format errors (bad request structure, image dimension/size).
func (e *FailoverError) IsRetriable() bool {
return e.Reason != FailoverFormat
return e.Reason != FailoverFormat && e.Reason != FailoverContextOverflow
}
// ModelConfig holds primary model and fallback list.