adding test units

This commit is contained in:
Badgerbees
2026-03-26 03:03:19 +07:00
parent 97dec16769
commit ae94893605
2 changed files with 134 additions and 0 deletions
+109
View File
@@ -3,6 +3,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@@ -2817,3 +2818,111 @@ func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
t.Fatalf("len(result) = %d, want 0", len(result))
}
}
type overflowProvider struct {
calls int
lastMessages []providers.Message
chatFunc func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any) (*providers.LLMResponse, error)
}
func (p *overflowProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.calls++
p.lastMessages = append([]providers.Message(nil), messages...)
if p.chatFunc != nil {
return p.chatFunc(ctx, messages, tools, model, opts)
}
if p.calls == 1 {
return nil, errors.New("context_window_exceeded")
}
return &providers.LLMResponse{
Content: "Recovered from overflow",
}, nil
}
func (p *overflowProvider) GetDefaultModel() string {
return "test-model"
}
func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
_ = cfg
provider := &overflowProvider{}
al.registry = NewAgentRegistry(al.cfg, provider)
sessionKey := "agent:main:test-session"
agent := al.GetRegistry().GetDefaultAgent()
for i := 0; i < 5; i++ {
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
SessionKey: "test-session",
Content: "trigger recovery",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Recovered from overflow" {
t.Fatalf("response = %q, want %q", response, "Recovered from overflow")
}
if provider.calls != 2 {
t.Fatalf("expected 2 calls, got %d", provider.calls)
}
}
func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
_ = cfg
provider := &overflowProvider{}
al.registry = NewAgentRegistry(al.cfg, provider)
recoveryMsg := "error: status 400: context_window_exceeded"
provider.chatFunc = func(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
if provider.calls == 1 {
return nil, errors.New(recoveryMsg)
}
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
Content: "hello",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if !strings.Contains(response, "Anthropic recovery success") {
t.Fatalf("response = %q, want success message", response)
}
if provider.calls != 2 {
t.Fatalf("expected 2 calls for retry, got %d", provider.calls)
}
}
+25
View File
@@ -221,6 +221,30 @@ func TestClassifyError_ImageDimensionError(t *testing.T) {
}
}
func TestClassifyError_ContextOverflowPatterns(t *testing.T) {
patterns := []string{
"context_length_exceeded",
"context_window_exceeded",
"maximum context length",
"token limit",
"too many tokens",
"prompt is too long",
"request too large",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverContextOverflow {
t.Errorf("pattern %q: reason = %q, want context_overflow", msg, result.Reason)
}
}
}
func TestClassifyError_ImageSizeError(t *testing.T) {
err := errors.New("image exceeds 20 mb limit")
result := ClassifyError(err, "openai", "gpt-4o")
@@ -265,6 +289,7 @@ func TestFailoverError_IsRetriable(t *testing.T) {
{FailoverTimeout, true},
{FailoverOverloaded, true},
{FailoverFormat, false},
{FailoverContextOverflow, false},
{FailoverUnknown, true},
}