From ae9489360578810711171809592f731bc4087f2e Mon Sep 17 00:00:00 2001 From: Badgerbees Date: Thu, 26 Mar 2026 03:03:19 +0700 Subject: [PATCH] adding test units --- pkg/agent/loop_test.go | 109 +++++++++++++++++++++++++ pkg/providers/error_classifier_test.go | 25 ++++++ 2 files changed, 134 insertions(+) diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 341d15a2f..2366b1277 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -3,6 +3,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -2817,3 +2818,111 @@ func TestFilterClientWebSearch_EmptyInput(t *testing.T) { t.Fatalf("len(result) = %d, want 0", len(result)) } } + +type overflowProvider struct { + calls int + lastMessages []providers.Message + chatFunc func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any) (*providers.LLMResponse, error) +} + +func (p *overflowProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.calls++ + p.lastMessages = append([]providers.Message(nil), messages...) + + if p.chatFunc != nil { + return p.chatFunc(ctx, messages, tools, model, opts) + } + + if p.calls == 1 { + return nil, errors.New("context_window_exceeded") + } + + return &providers.LLMResponse{ + Content: "Recovered from overflow", + }, nil +} + +func (p *overflowProvider) GetDefaultModel() string { + return "test-model" +} + +func TestProcessMessage_ContextOverflowRecovery(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + _ = cfg + + provider := &overflowProvider{} + al.registry = NewAgentRegistry(al.cfg, provider) + + sessionKey := "agent:main:test-session" + agent := al.GetRegistry().GetDefaultAgent() + + for i := 0; i < 5; i++ { + agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"}) + agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"}) + } + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "test", + ChatID: "chat1", + SenderID: "user1", + SessionKey: "test-session", + Content: "trigger recovery", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "Recovered from overflow" { + t.Fatalf("response = %q, want %q", response, "Recovered from overflow") + } + + if provider.calls != 2 { + t.Fatalf("expected 2 calls, got %d", provider.calls) + } +} + +func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + _ = cfg + + provider := &overflowProvider{} + al.registry = NewAgentRegistry(al.cfg, provider) + + recoveryMsg := "error: status 400: context_window_exceeded" + + provider.chatFunc = func( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, + ) (*providers.LLMResponse, error) { + if provider.calls == 1 { + return nil, errors.New(recoveryMsg) + } + return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil + } + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "test", + ChatID: "chat1", + SenderID: "user1", + Content: "hello", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if !strings.Contains(response, "Anthropic recovery success") { + t.Fatalf("response = %q, want success message", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 calls for retry, got %d", provider.calls) + } +} diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go index 67d9af62b..46b180835 100644 --- a/pkg/providers/error_classifier_test.go +++ b/pkg/providers/error_classifier_test.go @@ -221,6 +221,30 @@ func TestClassifyError_ImageDimensionError(t *testing.T) { } } +func TestClassifyError_ContextOverflowPatterns(t *testing.T) { + patterns := []string{ + "context_length_exceeded", + "context_window_exceeded", + "maximum context length", + "token limit", + "too many tokens", + "prompt is too long", + "request too large", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverContextOverflow { + t.Errorf("pattern %q: reason = %q, want context_overflow", msg, result.Reason) + } + } +} + func TestClassifyError_ImageSizeError(t *testing.T) { err := errors.New("image exceeds 20 mb limit") result := ClassifyError(err, "openai", "gpt-4o") @@ -265,6 +289,7 @@ func TestFailoverError_IsRetriable(t *testing.T) { {FailoverTimeout, true}, {FailoverOverloaded, true}, {FailoverFormat, false}, + {FailoverContextOverflow, false}, {FailoverUnknown, true}, }