mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
adding test units
This commit is contained in:
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -2817,3 +2818,111 @@ func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
|
||||
t.Fatalf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
type overflowProvider struct {
|
||||
calls int
|
||||
lastMessages []providers.Message
|
||||
chatFunc func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any) (*providers.LLMResponse, error)
|
||||
}
|
||||
|
||||
func (p *overflowProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.calls++
|
||||
p.lastMessages = append([]providers.Message(nil), messages...)
|
||||
|
||||
if p.chatFunc != nil {
|
||||
return p.chatFunc(ctx, messages, tools, model, opts)
|
||||
}
|
||||
|
||||
if p.calls == 1 {
|
||||
return nil, errors.New("context_window_exceeded")
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "Recovered from overflow",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *overflowProvider) GetDefaultModel() string {
|
||||
return "test-model"
|
||||
}
|
||||
|
||||
func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
_ = cfg
|
||||
|
||||
provider := &overflowProvider{}
|
||||
al.registry = NewAgentRegistry(al.cfg, provider)
|
||||
|
||||
sessionKey := "agent:main:test-session"
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
SessionKey: "test-session",
|
||||
Content: "trigger recovery",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Recovered from overflow" {
|
||||
t.Fatalf("response = %q, want %q", response, "Recovered from overflow")
|
||||
}
|
||||
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
_ = cfg
|
||||
|
||||
provider := &overflowProvider{}
|
||||
al.registry = NewAgentRegistry(al.cfg, provider)
|
||||
|
||||
recoveryMsg := "error: status 400: context_window_exceeded"
|
||||
|
||||
provider.chatFunc = func(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if provider.calls == 1 {
|
||||
return nil, errors.New(recoveryMsg)
|
||||
}
|
||||
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if !strings.Contains(response, "Anthropic recovery success") {
|
||||
t.Fatalf("response = %q, want success message", response)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("expected 2 calls for retry, got %d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +221,30 @@ func TestClassifyError_ImageDimensionError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextOverflowPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"context_length_exceeded",
|
||||
"context_window_exceeded",
|
||||
"maximum context length",
|
||||
"token limit",
|
||||
"too many tokens",
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverContextOverflow {
|
||||
t.Errorf("pattern %q: reason = %q, want context_overflow", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageSizeError(t *testing.T) {
|
||||
err := errors.New("image exceeds 20 mb limit")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
@@ -265,6 +289,7 @@ func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
{FailoverContextOverflow, false},
|
||||
{FailoverUnknown, true},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user