From 7824bc715f2c219403ccd38d77a155af41ad1dc8 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Tue, 14 Apr 2026 22:31:30 +0200 Subject: [PATCH] add test --- pkg/agent/loop_test.go | 129 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 9cca84b6b..183d65afb 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2565,6 +2565,135 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { } } +type visionUnsupportedMediaProvider struct { + calls int + mediaSeen []bool +} + +func (p *visionUnsupportedMediaProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.calls++ + + hasMedia := false + for _, msg := range messages { + for _, ref := range msg.Media { + if strings.TrimSpace(ref) != "" { + hasMedia = true + break + } + } + if hasMedia { + break + } + } + p.mediaSeen = append(p.mediaSeen, hasMedia) + + if hasMedia { + return nil, fmt.Errorf("API request failed: Status: 404 Body: {\"error\":{\"message\":\"No endpoints found that support image input\"}}") + } + + return &providers.LLMResponse{ + Content: "ok", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (p *visionUnsupportedMediaProvider) GetDefaultModel() string { + return "mock-fail-model" +} + +func TestAgentLoop_VisionUnsupportedErrorStripsSessionMedia(t *testing.T) { + workspace := t.TempDir() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 3, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &visionUnsupportedMediaProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + sessionKey := "agent:main:telegram:direct:user1" + + timeoutCtx, cancel := context.WithTimeout(context.Background(), responseTimeout) + defer cancel() + + resp, err := al.processMessage(timeoutCtx, testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + MessageID: "m1", + }, + Content: "describe this", + Media: []string{"data:image/png;base64,abc123"}, + SessionKey: sessionKey, + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if resp != "ok" { + t.Fatalf("response = %q, want %q", resp, "ok") + } + if provider.calls != 2 { + t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2) + } + if !slices.Equal(provider.mediaSeen, []bool{true, false}) { + t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false}) + } + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + history := agent.Sessions.GetHistory(sessionKey) + for i, msg := range history { + if len(msg.Media) > 0 { + t.Fatalf("history[%d].Media = %v, want no media after stripping", i, msg.Media) + } + } + + timeoutCtx2, cancel2 := context.WithTimeout(context.Background(), responseTimeout) + defer cancel2() + + resp2, err := al.processMessage(timeoutCtx2, testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + MessageID: "m2", + }, + Content: "hello again", + SessionKey: sessionKey, + })) + if err != nil { + t.Fatalf("processMessage() second call error = %v", err) + } + if resp2 != "ok" { + t.Fatalf("second response = %q, want %q", resp2, "ok") + } + if provider.calls != 3 { + t.Fatalf("calls after second turn = %d, want %d", provider.calls, 3) + } + if !slices.Equal(provider.mediaSeen, []bool{true, false, false}) { + t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false, false}) + } +} + func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil {