mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
cb1e1a3595
* fix(feishu): fix image download with API fallback and post image support - Add Image.Get API fallback when MessageResource.Get fails (different permission scope: im:resource vs im:message:readonly) - Extract and download images from post (rich text) messages - Extract images from interactive card messages - Deduplicate post image keys across locales - Add comprehensive tests for new helpers Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat(media): add image path tags alongside base64 for LLM file access Images are still base64-encoded into msg.Media for multimodal LLMs, but now also get [image:path] tags injected into message content so the LLM knows the local file path for save/forward operations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(media): only auto-inject images for tool results, not user messages Channel-received images (role=user) now get path tags only, letting the LLM decide whether to view via load_image or just operate on the file. Tool result images (role=tool, e.g. load_image) are base64-encoded into a synthetic user message appended after the tool message, since many LLM APIs don't support image_url in tool messages. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(media): preserve tool-message ordering for multi-tool-call scenarios Move synthetic user message (carrying base64 tool images) to after the entire contiguous tool-message block instead of immediately after each tool message. This preserves the assistant→tool→tool ordering required by OpenAI-compatible APIs. Also fix load_image to use generic [image: photo] placeholder so injectPathTags can properly replace it with the actual path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(test): update load_image test for [image: photo] placeholder The test was checking ForLLM for the media:// ref, but load_image now emits the generic [image: photo] placeholder instead. 
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(media): match all channel image placeholders in injectPathTags Different channels emit different placeholder formats — Telegram/Feishu use [image: photo], WeCom/WeChat/Line use bare [image], QQ/Discord use [image: <filename>]. The previous string-match code only handled [image: photo], so for the other channels the path tag was appended as a duplicate, producing content like "[image] [image:/path]". Switch to per-type regex that matches all generic placeholder shapes while leaving path tags ([image:/path]) untouched. Also fixes the same issue for [audio], [video], [file] tags. Added test coverage for the various placeholder shapes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(media): skip path tag append for JSON content (Feishu cards/posts) When content is structured JSON (interactive cards, post messages), injectPathTags now skips the fallback append — only placeholder replacement is attempted. This prevents corrupting JSON payloads like {"schema":"2.0",...} with appended [image:/path] tags. Adds looksLikeJSON() helper and three test cases covering JSON objects, arrays, and an end-to-end resolveMediaRefs scenario with Feishu card content. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(media): prepend path tags for JSON content, narrow looksLikeJSON Two fixes from code review: 1. looksLikeJSON now only checks for '{' prefix (not '['), avoiding false positives on regular text like "[update] see attached". 2. For JSON content (Feishu cards/posts), path tags are prepended before the JSON instead of being silently dropped. This ensures the LLM can discover attached images via the path tag while the JSON payload stays valid for downstream parsing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
5584 lines
163 KiB
Go
5584 lines
163 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/channels"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/media"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
"github.com/sipeed/picoclaw/pkg/session"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
"github.com/sipeed/picoclaw/pkg/utils"
|
|
)
|
|
|
|
type fakeChannel struct{ id string }
|
|
|
|
func (f *fakeChannel) Name() string { return "fake" }
|
|
func (f *fakeChannel) Start(ctx context.Context) error { return nil }
|
|
func (f *fakeChannel) Stop(ctx context.Context) error { return nil }
|
|
func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
func (f *fakeChannel) IsRunning() bool { return true }
|
|
func (f *fakeChannel) IsAllowed(string) bool { return true }
|
|
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
|
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
|
|
|
type fakeMediaChannel struct {
|
|
fakeChannel
|
|
sentMessages []bus.OutboundMessage
|
|
sentMedia []bus.OutboundMediaMessage
|
|
}
|
|
|
|
func (f *fakeMediaChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
|
f.sentMessages = append(f.sentMessages, msg)
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
|
f.sentMedia = append(f.sentMedia, msg)
|
|
return nil, nil
|
|
}
|
|
|
|
func newStartedTestChannelManager(
|
|
t *testing.T,
|
|
msgBus *bus.MessageBus,
|
|
store media.MediaStore,
|
|
name string,
|
|
ch channels.Channel,
|
|
) *channels.Manager {
|
|
t.Helper()
|
|
|
|
cm, err := channels.NewManager(&config.Config{}, msgBus, store)
|
|
if err != nil {
|
|
t.Fatalf("NewManager() error = %v", err)
|
|
}
|
|
cm.RegisterChannel(name, ch)
|
|
if err := cm.StartAll(context.Background()); err != nil {
|
|
t.Fatalf("StartAll() error = %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if err := cm.StopAll(context.Background()); err != nil {
|
|
t.Fatalf("StopAll() error = %v", err)
|
|
}
|
|
})
|
|
return cm
|
|
}
|
|
|
|
type recordingProvider struct {
|
|
lastMessages []providers.Message
|
|
lastModel string
|
|
}
|
|
|
|
func (r *recordingProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
r.lastMessages = append([]providers.Message(nil), messages...)
|
|
r.lastModel = model
|
|
return &providers.LLMResponse{
|
|
Content: "Mock response",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (r *recordingProvider) GetDefaultModel() string {
|
|
return "mock-model"
|
|
}
|
|
|
|
type modelRewriteHook struct {
|
|
model string
|
|
}
|
|
|
|
func (h modelRewriteHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
next := req.Clone()
|
|
next.Model = h.model
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h modelRewriteHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func useTestSideQuestionProvider(al *AgentLoop, provider providers.LLMProvider) {
|
|
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
|
model := provider.GetDefaultModel()
|
|
if mc != nil {
|
|
if _, modelID := providers.ExtractProtocol(mc); modelID != "" {
|
|
model = modelID
|
|
}
|
|
}
|
|
return provider, model, nil
|
|
}
|
|
}
|
|
|
|
func newTestAgentLoop(
|
|
t *testing.T,
|
|
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
|
t.Helper()
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
cfg = &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
msgBus = bus.NewMessageBus()
|
|
provider = &mockProvider{}
|
|
al = NewAgentLoop(cfg, msgBus, provider)
|
|
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
|
}
|
|
|
|
func TestNewAgentLoop_RegistersWebSearchTool(t *testing.T) {
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = t.TempDir()
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
|
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
if _, ok := agent.Tools.Get("web_search"); !ok {
|
|
t.Fatal("expected web_search tool to be registered")
|
|
}
|
|
}
|
|
|
|
func TestNewAgentLoop_RegistersWebSearchTool_WhenExplicitProviderUnavailable(t *testing.T) {
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = t.TempDir()
|
|
cfg.Tools.Web.Provider = "brave"
|
|
cfg.Tools.Web.Brave.Enabled = true
|
|
cfg.Tools.Web.Sogou.Enabled = true
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
|
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
if _, ok := agent.Tools.Get("web_search"); !ok {
|
|
t.Fatal("expected web_search tool to fall back to auto provider selection")
|
|
}
|
|
}
|
|
|
|
func TestNewAgentLoop_DoesNotRegisterWebSearchTool_WhenNoReadyProviders(t *testing.T) {
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = t.TempDir()
|
|
cfg.Tools.Web.Provider = "brave"
|
|
cfg.Tools.Web.Brave.Enabled = true
|
|
cfg.Tools.Web.Sogou.Enabled = false
|
|
cfg.Tools.Web.DuckDuckGo.Enabled = false
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
|
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
if _, ok := agent.Tools.Get("web_search"); ok {
|
|
t.Fatal("expected web_search tool to be absent when no providers are ready")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "discord",
|
|
SenderID: "discord:123",
|
|
Sender: bus.SenderInfo{
|
|
DisplayName: "Alice",
|
|
},
|
|
ChatID: "group-1",
|
|
Content: "hello",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages")
|
|
}
|
|
|
|
systemPrompt := provider.lastMessages[0].Content
|
|
wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
|
|
if !strings.Contains(systemPrompt, wantSender) {
|
|
t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
|
|
}
|
|
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "hello" {
|
|
t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
skillDir := filepath.Join(tmpDir, "skills", "shell")
|
|
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
|
t.Fatalf("mkdir skill dir: %v", err)
|
|
}
|
|
if err := os.WriteFile(
|
|
filepath.Join(skillDir, "SKILL.md"),
|
|
[]byte("# shell\n\nPrefer concise shell commands and explain them briefly."),
|
|
0o644,
|
|
); err != nil {
|
|
t.Fatalf("write skill file: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/use shell explain how to list files",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages")
|
|
}
|
|
|
|
systemPrompt := provider.lastMessages[0].Content
|
|
if !strings.Contains(systemPrompt, "# Active Skills") {
|
|
t.Fatalf("system prompt missing active skills section:\n%s", systemPrompt)
|
|
}
|
|
if !strings.Contains(systemPrompt, "### Skill: shell") {
|
|
t.Fatalf("system prompt missing requested skill content:\n%s", systemPrompt)
|
|
}
|
|
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "explain how to list files" {
|
|
t.Fatalf("last provider message = %+v, want rewritten user message", lastMessage)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
// Add model list so isolated provider can resolve the model
|
|
ModelList: []*config.ModelConfig{
|
|
{ModelName: "test-model", Model: "openai/test-model"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
msg := bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/btw explain side effects",
|
|
}
|
|
route, _, err := al.resolveMessageRoute(msg)
|
|
if err != nil {
|
|
t.Fatalf("resolveMessageRoute() error = %v", err)
|
|
}
|
|
allocation := al.allocateRouteSession(route, msg)
|
|
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
|
|
initialHistory := []providers.Message{
|
|
{Role: "user", Content: "We decided to avoid global state."},
|
|
{Role: "assistant", Content: "Right, keep it request-scoped."},
|
|
}
|
|
defaultAgent.Sessions.SetHistory(sessionKey, initialHistory)
|
|
defaultAgent.Sessions.SetSummary(sessionKey, "The team decided to keep state request-scoped.")
|
|
|
|
response, err := al.processMessage(context.Background(), msg)
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages")
|
|
}
|
|
if len(provider.lastMessages) != 4 {
|
|
t.Fatalf("provider messages len = %d, want 4 (system + prior history + user)", len(provider.lastMessages))
|
|
}
|
|
|
|
if !reflect.DeepEqual(provider.lastMessages[1:3], initialHistory) {
|
|
t.Fatalf("provider history = %#v, want %#v", provider.lastMessages[1:3], initialHistory)
|
|
}
|
|
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "explain side effects" {
|
|
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
|
}
|
|
|
|
history := al.GetRegistry().GetDefaultAgent().Sessions.GetHistory(sessionKey)
|
|
if !reflect.DeepEqual(history, initialHistory) {
|
|
t.Fatalf("session history = %#v, want %#v", history, initialHistory)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandIncludesRequestContextAndMedia(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "discord",
|
|
SenderID: "discord:123",
|
|
Sender: bus.SenderInfo{
|
|
DisplayName: "Alice",
|
|
},
|
|
ChatID: "group-1",
|
|
Content: "/btw describe this image",
|
|
Media: []string{"media://image-1"},
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages")
|
|
}
|
|
|
|
systemPrompt := provider.lastMessages[0].Content
|
|
if !strings.Contains(systemPrompt, "## Current Session\nChannel: discord\nChat ID: group-1") {
|
|
t.Fatalf("system prompt missing current session context:\n%s", systemPrompt)
|
|
}
|
|
if !strings.Contains(systemPrompt, "## Current Sender\nCurrent sender: Alice (ID: discord:123)") {
|
|
t.Fatalf("system prompt missing current sender context:\n%s", systemPrompt)
|
|
}
|
|
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "describe this image" {
|
|
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
|
}
|
|
if !reflect.DeepEqual(lastMessage.Media, []string{"media://image-1"}) {
|
|
t.Fatalf("last provider media = %#v, want media ref", lastMessage.Media)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
// Add model list so isolated provider can resolve the model
|
|
ModelList: []*config.ModelConfig{
|
|
{ModelName: "test-model", Model: "openai/test-model"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
// Set up initial history for the main session
|
|
mainSessionKey := "telegram:123:chat-1"
|
|
initialHistory := []providers.Message{
|
|
{Role: "user", Content: "We decided to avoid global state."},
|
|
{Role: "assistant", Content: "Right, keep it request-scoped."},
|
|
}
|
|
defaultAgent.Sessions.SetHistory(mainSessionKey, initialHistory)
|
|
|
|
// Process a /btw command
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
SessionKey: mainSessionKey,
|
|
Content: "/btw explain isolation",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
|
|
// Verify the provider received the side question
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages for /btw command")
|
|
}
|
|
|
|
// Verify the question was stripped of /btw prefix
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "explain isolation" {
|
|
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
|
}
|
|
|
|
// Verify main session history was NOT modified
|
|
currentHistory := defaultAgent.Sessions.GetHistory(mainSessionKey)
|
|
if !reflect.DeepEqual(currentHistory, initialHistory) {
|
|
t.Fatalf("main session history was modified:\ngot %#v\nwant %#v", currentHistory, initialHistory)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
// Add model list so isolated provider can resolve the model
|
|
ModelList: []*config.ModelConfig{
|
|
{ModelName: "test-model", Model: "openai/test-model"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &visionUnsupportedMediaProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/btw describe this image",
|
|
Media: []string{"data:image/png;base64,abc123"},
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "ok" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
|
|
}
|
|
// Note: With isolated providers, each /btw creates a new provider instance,
|
|
// so we can't track calls across retries in the same way.
|
|
// The retry logic happens within askSideQuestion, creating separate isolated providers.
|
|
// For now, we just verify the command succeeds.
|
|
if provider.calls < 1 {
|
|
t.Fatalf("provider was not called for /btw command")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "lb-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{ModelName: "lb-model", Model: "openai/lb-model-a"},
|
|
{ModelName: "lb-model", Model: "openai/lb-model-b"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/btw explain load balancing",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
|
|
// Verify that /btw used the configured model from ModelList
|
|
// The provider should have been called with one of the lb-model variants
|
|
if provider.lastModel == "" {
|
|
t.Fatal("provider was not called for /btw command")
|
|
}
|
|
if !strings.HasPrefix(provider.lastModel, "lb-model") {
|
|
t.Fatalf("/btw used model %q, expected lb-model variant", provider.lastModel)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_BtwCommandHookModelBypassesFallbackCandidates(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "primary-model",
|
|
ModelFallbacks: []string{"fallback-model"},
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
useTestSideQuestionProvider(al, provider)
|
|
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "hook-model"})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/btw explain hook routing",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
|
}
|
|
if provider.lastModel != "hook-model" {
|
|
t.Fatalf("/btw model = %q, want hook-selected model", provider.lastModel)
|
|
}
|
|
}
|
|
|
|
func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
|
|
opts := processOptions{}
|
|
reply, handled := al.handleCommand(context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/use missing explain how to list files",
|
|
}, agent, &opts)
|
|
if !handled {
|
|
t.Fatal("expected /use with unknown skill to be handled")
|
|
}
|
|
if !strings.Contains(reply, "Unknown skill: missing") {
|
|
t.Fatalf("reply = %q, want unknown skill error", reply)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
skillDir := filepath.Join(tmpDir, "skills", "shell")
|
|
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
|
t.Fatalf("mkdir skill dir: %v", err)
|
|
}
|
|
if err := os.WriteFile(
|
|
filepath.Join(skillDir, "SKILL.md"),
|
|
[]byte("# shell\n\nPrefer concise shell commands and explain them briefly."),
|
|
0o644,
|
|
); err != nil {
|
|
t.Fatalf("write skill file: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &recordingProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "/use shell",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() arm error = %v", err)
|
|
}
|
|
if !strings.Contains(response, `Skill "shell" is armed for your next message.`) {
|
|
t.Fatalf("arm response = %q, want armed confirmation", response)
|
|
}
|
|
|
|
response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "telegram:123",
|
|
ChatID: "chat-1",
|
|
Content: "explain how to list files",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() follow-up error = %v", err)
|
|
}
|
|
if response != "Mock response" {
|
|
t.Fatalf("follow-up response = %q, want %q", response, "Mock response")
|
|
}
|
|
if len(provider.lastMessages) == 0 {
|
|
t.Fatal("provider did not receive any messages")
|
|
}
|
|
|
|
systemPrompt := provider.lastMessages[0].Content
|
|
if !strings.Contains(systemPrompt, "### Skill: shell") {
|
|
t.Fatalf("system prompt missing pending skill content:\n%s", systemPrompt)
|
|
}
|
|
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
|
if lastMessage.Role != "user" || lastMessage.Content != "explain how to list files" {
|
|
t.Fatalf("last provider message = %+v, want unchanged follow-up user message", lastMessage)
|
|
}
|
|
}
|
|
|
|
func TestApplyExplicitSkillCommand_ArmsSkillForNextMessage(t *testing.T) {
|
|
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil {
|
|
t.Fatalf("MkdirAll(skill) error = %v", err)
|
|
}
|
|
if err := os.WriteFile(
|
|
filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"),
|
|
[]byte("# Finance News\n\nUse web tools for current finance updates.\n"),
|
|
0o644,
|
|
); err != nil {
|
|
t.Fatalf("WriteFile(SKILL.md) error = %v", err)
|
|
}
|
|
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
opts := &processOptions{SessionKey: "agent:main:test"}
|
|
matched, handled, reply := al.applyExplicitSkillCommand("/use finance-news", agent, opts)
|
|
if !matched {
|
|
t.Fatal("expected /use command to match")
|
|
}
|
|
if !handled {
|
|
t.Fatal("expected /use without inline message to be handled immediately")
|
|
}
|
|
if !strings.Contains(reply, `Skill "finance-news" is armed for your next message`) {
|
|
t.Fatalf("unexpected reply: %q", reply)
|
|
}
|
|
|
|
pending := al.takePendingSkills(opts.SessionKey)
|
|
if len(pending) != 1 || pending[0] != "finance-news" {
|
|
t.Fatalf("pending skills = %#v, want [finance-news]", pending)
|
|
}
|
|
}
|
|
|
|
func TestApplyExplicitSkillCommand_InlineMessageMutatesOptions(t *testing.T) {
|
|
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil {
|
|
t.Fatalf("MkdirAll(skill) error = %v", err)
|
|
}
|
|
if err := os.WriteFile(
|
|
filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"),
|
|
[]byte("# Finance News\n\nUse web tools for current finance updates.\n"),
|
|
0o644,
|
|
); err != nil {
|
|
t.Fatalf("WriteFile(SKILL.md) error = %v", err)
|
|
}
|
|
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
opts := &processOptions{
|
|
SessionKey: "agent:main:test",
|
|
UserMessage: "/use finance-news dammi le ultime news",
|
|
}
|
|
matched, handled, reply := al.applyExplicitSkillCommand(opts.UserMessage, agent, opts)
|
|
if !matched {
|
|
t.Fatal("expected /use command to match")
|
|
}
|
|
if handled {
|
|
t.Fatal("expected /use with inline message to fall through into normal agent execution")
|
|
}
|
|
if reply != "" {
|
|
t.Fatalf("unexpected reply: %q", reply)
|
|
}
|
|
if opts.UserMessage != "dammi le ultime news" {
|
|
t.Fatalf("opts.UserMessage = %q, want %q", opts.UserMessage, "dammi le ultime news")
|
|
}
|
|
if len(opts.ForcedSkills) != 1 || opts.ForcedSkills[0] != "finance-news" {
|
|
t.Fatalf("opts.ForcedSkills = %#v, want [finance-news]", opts.ForcedSkills)
|
|
}
|
|
}
|
|
|
|
func TestRecordLastChannel(t *testing.T) {
|
|
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
testChannel := "test-channel"
|
|
if err := al.RecordLastChannel(testChannel); err != nil {
|
|
t.Fatalf("RecordLastChannel failed: %v", err)
|
|
}
|
|
if got := al.state.GetLastChannel(); got != testChannel {
|
|
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
|
|
}
|
|
al2 := NewAgentLoop(cfg, msgBus, provider)
|
|
if got := al2.state.GetLastChannel(); got != testChannel {
|
|
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
|
|
}
|
|
}
|
|
|
|
func TestRecordLastChatID(t *testing.T) {
|
|
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
testChatID := "test-chat-id-123"
|
|
if err := al.RecordLastChatID(testChatID); err != nil {
|
|
t.Fatalf("RecordLastChatID failed: %v", err)
|
|
}
|
|
if got := al.state.GetLastChatID(); got != testChatID {
|
|
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
|
|
}
|
|
al2 := NewAgentLoop(cfg, msgBus, provider)
|
|
if got := al2.state.GetLastChatID(); got != testChatID {
|
|
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
|
|
}
|
|
}
|
|
|
|
func TestNewAgentLoop_StateInitialized(t *testing.T) {
|
|
// Create temp workspace
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
// Create test config
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Create agent loop
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Verify state manager is initialized
|
|
if al.state == nil {
|
|
t.Error("Expected state manager to be initialized")
|
|
}
|
|
|
|
// Verify state directory was created
|
|
stateDir := filepath.Join(tmpDir, "state")
|
|
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
|
|
t.Error("Expected state directory to exist")
|
|
}
|
|
}
|
|
|
|
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
|
|
func TestToolRegistry_ToolRegistration(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Register a custom tool
|
|
customTool := &mockCustomTool{}
|
|
al.RegisterTool(customTool)
|
|
|
|
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
|
|
// (actual tool retrieval is tested in tools package tests)
|
|
info := al.GetStartupInfo()
|
|
toolsInfo := info["tools"].(map[string]any)
|
|
toolsList := toolsInfo["names"].([]string)
|
|
|
|
// Check that our custom tool name is in the list
|
|
found := slices.Contains(toolsList, "mock_custom")
|
|
if !found {
|
|
t.Error("Expected custom tool to be registered")
|
|
}
|
|
}
|
|
|
|
// TestToolContext_Updates verifies tool context helpers work correctly
|
|
func TestToolContext_Updates(t *testing.T) {
|
|
ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42")
|
|
|
|
if got := tools.ToolChannel(ctx); got != "telegram" {
|
|
t.Errorf("expected channel 'telegram', got %q", got)
|
|
}
|
|
if got := tools.ToolChatID(ctx); got != "chat-42" {
|
|
t.Errorf("expected chatID 'chat-42', got %q", got)
|
|
}
|
|
|
|
// Empty context returns empty strings
|
|
if got := tools.ToolChannel(context.Background()); got != "" {
|
|
t.Errorf("expected empty channel from bare context, got %q", got)
|
|
}
|
|
|
|
inboundCtx := tools.WithToolInboundContext(
|
|
context.Background(),
|
|
"telegram",
|
|
"chat-42",
|
|
"msg-123",
|
|
"msg-100",
|
|
)
|
|
if got := tools.ToolMessageID(inboundCtx); got != "msg-123" {
|
|
t.Errorf("expected messageID 'msg-123', got %q", got)
|
|
}
|
|
if got := tools.ToolReplyToMessageID(inboundCtx); got != "msg-100" {
|
|
t.Errorf("expected replyToMessageID 'msg-100', got %q", got)
|
|
}
|
|
}
|
|
|
|
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
|
|
func TestToolRegistry_GetDefinitions(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Register a test tool and verify it shows up in startup info
|
|
testTool := &mockCustomTool{}
|
|
al.RegisterTool(testTool)
|
|
|
|
info := al.GetStartupInfo()
|
|
toolsInfo := info["tools"].(map[string]any)
|
|
toolsList := toolsInfo["names"].([]string)
|
|
|
|
// Check that our custom tool name is in the list
|
|
found := slices.Contains(toolsList, "mock_custom")
|
|
if !found {
|
|
t.Error("Expected custom tool to be registered")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &handledMediaProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
store := media.NewFileMediaStore()
|
|
al.SetMediaStore(store)
|
|
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
|
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
|
|
|
imagePath := filepath.Join(tmpDir, "screen.png")
|
|
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
|
}
|
|
|
|
al.RegisterTool(&handledMediaTool{
|
|
store: store,
|
|
path: imagePath,
|
|
})
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "take a screenshot of the screen and send it to me",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "" {
|
|
t.Fatalf("expected no final response when media tool already handled delivery, got %q", response)
|
|
}
|
|
if provider.calls != 1 {
|
|
t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls)
|
|
}
|
|
if len(provider.toolCounts) != 1 {
|
|
t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts))
|
|
}
|
|
if provider.toolCounts[0] == 0 {
|
|
t.Fatal("expected tools to be available on the first LLM call")
|
|
}
|
|
|
|
if len(telegramChannel.sentMedia) != 1 {
|
|
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
|
}
|
|
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
|
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
|
}
|
|
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
|
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
|
}
|
|
|
|
select {
|
|
case extra := <-msgBus.OutboundMediaChan():
|
|
t.Fatalf("expected handled media to bypass async queue, got %+v", extra)
|
|
default:
|
|
}
|
|
|
|
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "take a screenshot of the screen and send it to me",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("resolveMessageRoute() error = %v", err)
|
|
}
|
|
sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "take a screenshot of the screen and send it to me",
|
|
})).SessionKey, "")
|
|
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
|
if len(history) == 0 {
|
|
t.Fatal("expected session history to be saved")
|
|
}
|
|
last := history[len(history)-1]
|
|
if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." {
|
|
t.Fatalf("expected handled assistant summary in history, got %+v", last)
|
|
}
|
|
if len(last.Attachments) != 1 {
|
|
t.Fatalf("expected handled assistant summary attachments in history, got %+v", last.Attachments)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &handledMediaWithSteeringProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
store := media.NewFileMediaStore()
|
|
al.SetMediaStore(store)
|
|
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
|
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
|
|
|
imagePath := filepath.Join(tmpDir, "screen-steering.png")
|
|
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
|
}
|
|
|
|
al.RegisterTool(&handledMediaWithSteeringTool{
|
|
store: store,
|
|
path: imagePath,
|
|
loop: al,
|
|
})
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "take a screenshot of the screen and send it to me",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Handled the queued steering message." {
|
|
t.Fatalf("response = %q, want queued steering response", response)
|
|
}
|
|
if provider.calls != 2 {
|
|
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
|
|
}
|
|
if len(telegramChannel.sentMedia) != 1 {
|
|
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
|
}
|
|
}
|
|
|
|
func TestRunAgentLoop_ResponseHandledToolPublishesForUserWhenSendResponseDisabled(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = tmpDir
|
|
cfg.Agents.Defaults.ModelName = "test-model"
|
|
cfg.Agents.Defaults.MaxTokens = 4096
|
|
cfg.Agents.Defaults.MaxToolIterations = 10
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &handledUserProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
store := media.NewFileMediaStore()
|
|
al.SetMediaStore(store)
|
|
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
|
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
|
al.RegisterTool(&handledUserTool{})
|
|
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
|
Dispatch: DispatchRequest{
|
|
SessionKey: "session-1",
|
|
UserMessage: "take a screenshot of the screen and send it to me",
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: defaultAgent.ID,
|
|
Channel: "telegram",
|
|
Dimensions: []string{"chat"},
|
|
Values: map[string]string{
|
|
"chat": "direct:chat1",
|
|
},
|
|
},
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
},
|
|
},
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop() error = %v", err)
|
|
}
|
|
if response != "" {
|
|
t.Fatalf("expected no final response when tool already handled delivery, got %q", response)
|
|
}
|
|
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for len(telegramChannel.sentMessages) == 0 && time.Now().Before(deadline) {
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
if len(telegramChannel.sentMessages) != 1 {
|
|
t.Fatalf("expected exactly 1 sent text message, got %d", len(telegramChannel.sentMessages))
|
|
}
|
|
if telegramChannel.sentMessages[0].Content != "Handled user output from tool." {
|
|
t.Fatalf("unexpected sent text message: %+v", telegramChannel.sentMessages[0])
|
|
}
|
|
if telegramChannel.sentMessages[0].AgentID != defaultAgent.ID {
|
|
t.Fatalf("sent text agent_id = %q, want %q", telegramChannel.sentMessages[0].AgentID, defaultAgent.ID)
|
|
}
|
|
if telegramChannel.sentMessages[0].SessionKey != "session-1" {
|
|
t.Fatalf("sent text session_key = %q, want session-1", telegramChannel.sentMessages[0].SessionKey)
|
|
}
|
|
if telegramChannel.sentMessages[0].Scope == nil ||
|
|
telegramChannel.sentMessages[0].Scope.Values["chat"] != "direct:chat1" {
|
|
t.Fatalf("unexpected sent text scope: %+v", telegramChannel.sentMessages[0].Scope)
|
|
}
|
|
}
|
|
|
|
func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) {
|
|
fields := map[string]any{}
|
|
|
|
appendEventContextFields(fields, &TurnContext{
|
|
Inbound: &bus.InboundContext{
|
|
Channel: "slack",
|
|
Account: "workspace-a",
|
|
ChatID: "C123",
|
|
ChatType: "channel",
|
|
TopicID: "thread-42",
|
|
SpaceType: "workspace",
|
|
SpaceID: "T001",
|
|
SenderID: "U123",
|
|
Mentioned: true,
|
|
},
|
|
Route: &routing.ResolvedRoute{
|
|
AgentID: "support",
|
|
Channel: "slack",
|
|
AccountID: "workspace-a",
|
|
MatchedBy: "default",
|
|
SessionPolicy: routing.SessionPolicy{
|
|
Dimensions: []string{"chat", "sender"},
|
|
IdentityLinks: map[string][]string{
|
|
"canonical-user": {"slack:U123"},
|
|
},
|
|
},
|
|
},
|
|
Scope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: "support",
|
|
Channel: "slack",
|
|
Account: "workspace-a",
|
|
Dimensions: []string{"chat", "sender"},
|
|
Values: map[string]string{
|
|
"chat": "channel:c123",
|
|
"sender": "u123",
|
|
},
|
|
},
|
|
})
|
|
|
|
if fields["inbound_channel"] != "slack" {
|
|
t.Fatalf("inbound_channel = %v, want slack", fields["inbound_channel"])
|
|
}
|
|
if fields["inbound_topic_id"] != "thread-42" {
|
|
t.Fatalf("inbound_topic_id = %v, want thread-42", fields["inbound_topic_id"])
|
|
}
|
|
if fields["route_matched_by"] != "default" {
|
|
t.Fatalf("route_matched_by = %v, want default", fields["route_matched_by"])
|
|
}
|
|
if fields["route_dimensions"] != "chat,sender" {
|
|
t.Fatalf("route_dimensions = %v, want chat,sender", fields["route_dimensions"])
|
|
}
|
|
if fields["route_identity_link_count"] != 1 {
|
|
t.Fatalf("route_identity_link_count = %v, want 1", fields["route_identity_link_count"])
|
|
}
|
|
if fields["scope_dimensions"] != "chat,sender" {
|
|
t.Fatalf("scope_dimensions = %v, want chat,sender", fields["scope_dimensions"])
|
|
}
|
|
if fields["scope_chat"] != "channel:c123" {
|
|
t.Fatalf("scope_chat = %v, want channel:c123", fields["scope_chat"])
|
|
}
|
|
if fields["scope_sender"] != "u123" {
|
|
t.Fatalf("scope_sender = %v, want u123", fields["scope_sender"])
|
|
}
|
|
}
|
|
|
|
func TestResolveMessageRoute_UsesInboundContextAccount(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
},
|
|
List: []config.AgentConfig{
|
|
{ID: "main", Default: true},
|
|
{ID: "work"},
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
|
|
|
|
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "slack",
|
|
Account: "workspace-a",
|
|
ChatID: "C123",
|
|
ChatType: "channel",
|
|
SenderID: "U123",
|
|
SpaceID: "T001",
|
|
SpaceType: "workspace",
|
|
},
|
|
Content: "hello",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("resolveMessageRoute() error = %v", err)
|
|
}
|
|
if route.AgentID != "main" {
|
|
t.Fatalf("AgentID = %q, want main", route.AgentID)
|
|
}
|
|
if route.MatchedBy != "default" {
|
|
t.Fatalf("MatchedBy = %q, want default", route.MatchedBy)
|
|
}
|
|
if route.AccountID != "workspace-a" {
|
|
t.Fatalf("AccountID = %q, want workspace-a", route.AccountID)
|
|
}
|
|
}
|
|
|
|
func TestResolveMessageRoute_UsesDispatchRulesInOrder(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
},
|
|
List: []config.AgentConfig{
|
|
{ID: "main", Default: true},
|
|
{ID: "support"},
|
|
{ID: "sales"},
|
|
},
|
|
Dispatch: &config.DispatchConfig{
|
|
Rules: []config.DispatchRule{
|
|
{
|
|
Name: "support-group",
|
|
Agent: "support",
|
|
When: config.DispatchSelector{
|
|
Channel: "telegram",
|
|
Chat: "group:-100123",
|
|
},
|
|
SessionDimensions: []string{"chat"},
|
|
},
|
|
{
|
|
Name: "vip-in-group",
|
|
Agent: "sales",
|
|
When: config.DispatchSelector{
|
|
Channel: "telegram",
|
|
Chat: "group:-100123",
|
|
Sender: "12345",
|
|
},
|
|
SessionDimensions: []string{"chat", "sender"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
|
|
|
|
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "-100123",
|
|
ChatType: "group",
|
|
SenderID: "12345",
|
|
},
|
|
Content: "hello",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("resolveMessageRoute() error = %v", err)
|
|
}
|
|
if route.AgentID != "support" {
|
|
t.Fatalf("AgentID = %q, want support", route.AgentID)
|
|
}
|
|
if route.MatchedBy != "dispatch.rule:support-group" {
|
|
t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy)
|
|
}
|
|
if got := route.SessionPolicy.Dimensions; len(got) != 1 || got[0] != "chat" {
|
|
t.Fatalf("SessionPolicy.Dimensions = %v, want [chat]", got)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = tmpDir
|
|
cfg.Agents.Defaults.ModelName = "test-model"
|
|
cfg.Agents.Defaults.MaxTokens = 4096
|
|
cfg.Agents.Defaults.MaxToolIterations = 10
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &artifactThenSendProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
store := media.NewFileMediaStore()
|
|
al.SetMediaStore(store)
|
|
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
|
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
|
|
|
mediaDir := media.TempDir()
|
|
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
|
}
|
|
imagePath := filepath.Join(mediaDir, "artifact-screen.png")
|
|
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
|
}
|
|
|
|
al.RegisterTool(&mediaArtifactTool{
|
|
store: store,
|
|
path: imagePath,
|
|
})
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "take a screenshot of the screen and send it to me",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "" {
|
|
t.Fatalf("expected no final response after send_file handled delivery, got %q", response)
|
|
}
|
|
if provider.calls != 2 {
|
|
t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
|
|
}
|
|
|
|
if len(telegramChannel.sentMedia) != 1 {
|
|
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
|
}
|
|
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
|
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
|
}
|
|
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
|
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
|
}
|
|
|
|
select {
|
|
case extra := <-msgBus.OutboundMediaChan():
|
|
t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra)
|
|
default:
|
|
}
|
|
}
|
|
|
|
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
|
|
func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = tmpDir
|
|
cfg.Agents.Defaults.ModelName = "test-model"
|
|
cfg.Agents.Defaults.MaxTokens = 4096
|
|
cfg.Agents.Defaults.MaxToolIterations = 10
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
info := al.GetStartupInfo()
|
|
|
|
// Verify tools info exists
|
|
toolsInfo, ok := info["tools"]
|
|
if !ok {
|
|
t.Fatal("Expected 'tools' key in startup info")
|
|
}
|
|
|
|
toolsMap, ok := toolsInfo.(map[string]any)
|
|
if !ok {
|
|
t.Fatal("Expected 'tools' to be a map")
|
|
}
|
|
|
|
count, ok := toolsMap["count"]
|
|
if !ok {
|
|
t.Fatal("Expected 'count' in tools info")
|
|
}
|
|
|
|
// Should have default tools registered
|
|
if count.(int) == 0 {
|
|
t.Error("Expected at least some tools to be registered")
|
|
}
|
|
}
|
|
|
|
// TestAgentLoop_Stop verifies Stop() sets running to false
|
|
func TestAgentLoop_Stop(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Note: running is only set to true when Run() is called
|
|
// We can't test that without starting the event loop
|
|
// Instead, verify the Stop method can be called safely
|
|
al.Stop()
|
|
|
|
// Verify running is false (initial state or after Stop)
|
|
if al.running.Load() {
|
|
t.Error("Expected agent to be stopped (or never started)")
|
|
}
|
|
}
|
|
|
|
// Mock implementations for testing
|
|
|
|
type simpleMockProvider struct {
|
|
response string
|
|
}
|
|
|
|
func (m *simpleMockProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{
|
|
Content: m.response,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *simpleMockProvider) GetDefaultModel() string {
|
|
return "mock-model"
|
|
}
|
|
|
|
type reasoningContentProvider struct {
|
|
response string
|
|
reasoningContent string
|
|
}
|
|
|
|
func (m *reasoningContentProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{
|
|
Content: m.response,
|
|
ReasoningContent: m.reasoningContent,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *reasoningContentProvider) GetDefaultModel() string {
|
|
return "reasoning-content-model"
|
|
}
|
|
|
|
type countingMockProvider struct {
|
|
response string
|
|
calls int
|
|
}
|
|
|
|
func (m *countingMockProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
return &providers.LLMResponse{
|
|
Content: m.response,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *countingMockProvider) GetDefaultModel() string {
|
|
return "counting-mock-model"
|
|
}
|
|
|
|
type handledMediaProvider struct {
|
|
calls int
|
|
toolCounts []int
|
|
}
|
|
|
|
func (m *handledMediaProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
m.toolCounts = append(m.toolCounts, len(tools))
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Taking the screenshot now.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_handled_media",
|
|
Type: "function",
|
|
Name: "handled_media_tool",
|
|
Arguments: map[string]any{},
|
|
}},
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{}, nil
|
|
}
|
|
|
|
func (m *handledMediaProvider) GetDefaultModel() string {
|
|
return "handled-media-model"
|
|
}
|
|
|
|
type handledUserProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *handledUserProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Delivering the result now.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_handled_user",
|
|
Type: "function",
|
|
Name: "handled_user_tool",
|
|
Arguments: map[string]any{},
|
|
}},
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{}, nil
|
|
}
|
|
|
|
func (m *handledUserProvider) GetDefaultModel() string {
|
|
return "handled-user-model"
|
|
}
|
|
|
|
type messageToolProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *messageToolProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_message",
|
|
Type: "function",
|
|
Name: "message",
|
|
Arguments: map[string]any{"content": "direct tool message"},
|
|
}},
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{}, nil
|
|
}
|
|
|
|
func (m *messageToolProvider) GetDefaultModel() string {
|
|
return "message-tool-model"
|
|
}
|
|
|
|
type reasoningVisibleToolProvider struct {
|
|
filePath string
|
|
calls int
|
|
}
|
|
|
|
func (m *reasoningVisibleToolProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "I'll inspect that file now.",
|
|
ReasoningContent: "Read the file before answering.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_read_file",
|
|
Type: "function",
|
|
Name: "read_file",
|
|
Arguments: map[string]any{"path": m.filePath},
|
|
}},
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{Content: "DONE"}, nil
|
|
}
|
|
|
|
func (m *reasoningVisibleToolProvider) GetDefaultModel() string {
|
|
return "reasoning-visible-tool-model"
|
|
}
|
|
|
|
type artifactThenSendProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *artifactThenSendProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Taking the screenshot now.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_artifact_media",
|
|
Type: "function",
|
|
Name: "media_artifact_tool",
|
|
Arguments: map[string]any{},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
var artifactPath string
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role != "tool" {
|
|
continue
|
|
}
|
|
for _, prefix := range []string{"[image:", "[file:", "[audio:", "[video:"} {
|
|
start := strings.Index(messages[i].Content, prefix)
|
|
if start < 0 {
|
|
continue
|
|
}
|
|
rest := messages[i].Content[start+len(prefix):]
|
|
end := strings.Index(rest, "]")
|
|
if end < 0 {
|
|
continue
|
|
}
|
|
artifactPath = rest[:end]
|
|
break
|
|
}
|
|
if artifactPath != "" {
|
|
break
|
|
}
|
|
}
|
|
if artifactPath == "" {
|
|
return nil, fmt.Errorf("provider did not receive artifact path in tool result")
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_send_file",
|
|
Type: "function",
|
|
Name: "send_file",
|
|
Arguments: map[string]any{"path": artifactPath},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
func (m *artifactThenSendProvider) GetDefaultModel() string {
|
|
return "artifact-then-send-model"
|
|
}
|
|
|
|
type toolFeedbackProvider struct {
|
|
filePath string
|
|
calls int
|
|
}
|
|
|
|
func (m *toolFeedbackProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_heartbeat_read_file",
|
|
Type: "function",
|
|
Name: "read_file",
|
|
Arguments: map[string]any{"path": m.filePath},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "HEARTBEAT_OK",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *toolFeedbackProvider) GetDefaultModel() string {
|
|
return "heartbeat-tool-feedback-model"
|
|
}
|
|
|
|
type toolFeedbackReasoningProvider struct {
|
|
filePath string
|
|
calls int
|
|
}
|
|
|
|
func (m *toolFeedbackReasoningProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ReasoningContent: "Read README.md first to confirm the context that needs to be changed.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_reasoning_read_file",
|
|
Type: "function",
|
|
Name: "read_file",
|
|
Arguments: map[string]any{"path": m.filePath},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "DONE",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *toolFeedbackReasoningProvider) GetDefaultModel() string {
|
|
return "tool-feedback-reasoning-model"
|
|
}
|
|
|
|
func TestToolFeedbackExplanationFromResponse_UsesCurrentContentFirst(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
Content: "Read README.md first",
|
|
ReasoningContent: "current reasoning fallback",
|
|
}
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "check file"},
|
|
{Role: "assistant", Content: "Previous turn explanation"},
|
|
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
|
}
|
|
|
|
got := toolFeedbackExplanationFromResponse(response, messages)
|
|
if got != "Read README.md first" {
|
|
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want current content", got)
|
|
}
|
|
}
|
|
|
|
func TestSideQuestionResponseContent_FallsBackWhenContentIsWhitespace(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
Content: " \n\t ",
|
|
ReasoningContent: "reasoning fallback",
|
|
}
|
|
|
|
if got := sideQuestionResponseContent(response); got != "reasoning fallback" {
|
|
t.Fatalf("sideQuestionResponseContent() = %q, want %q", got, "reasoning fallback")
|
|
}
|
|
}
|
|
|
|
func TestResponseReasoningContent_FallsBackWhenReasoningIsWhitespace(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
Reasoning: " \n\t ",
|
|
ReasoningContent: "structured reasoning fallback",
|
|
}
|
|
|
|
if got := responseReasoningContent(response); got != "structured reasoning fallback" {
|
|
t.Fatalf("responseReasoningContent() = %q, want %q", got, "structured reasoning fallback")
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackExplanationFromResponse_UsesExplicitToolCallExtraContent(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_1",
|
|
Name: "read_file",
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: "Read README.md first to confirm the current project structure.",
|
|
},
|
|
}},
|
|
}
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "check file"},
|
|
{Role: "assistant", Content: ""},
|
|
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
|
}
|
|
|
|
got := toolFeedbackExplanationFromResponse(response, messages)
|
|
if got != "Read README.md first to confirm the current project structure." {
|
|
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want explicit tool feedback explanation", got)
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackExplanationForToolCall_PrefersToolSpecificExtraContent(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
Content: "Shared explanation",
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Name: "read_file",
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: "Read README.md first.",
|
|
},
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Name: "edit_file",
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: "Update config example after reading it.",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
got1 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], nil)
|
|
got2 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[1], nil)
|
|
if got1 != "Read README.md first." {
|
|
t.Fatalf("toolFeedbackExplanationForToolCall() first = %q, want tool-specific explanation", got1)
|
|
}
|
|
if got2 != "Update config example after reading it." {
|
|
t.Fatalf("toolFeedbackExplanationForToolCall() second = %q, want tool-specific explanation", got2)
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackExplanationForToolCall_DoesNotReuseAnotherToolCallExplanation(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Name: "read_file",
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Name: "edit_file",
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: "Update config example after reading it.",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "inspect the config and update the example"},
|
|
}
|
|
|
|
got := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], messages)
|
|
want := utils.ToolFeedbackContinuationHint + ": inspect the config and update the example"
|
|
if got != want {
|
|
t.Fatalf("toolFeedbackExplanationForToolCall() = %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackExplanationFromResponse_DoesNotUseReasoningContent(t *testing.T) {
|
|
response := &providers.LLMResponse{
|
|
Content: "",
|
|
ReasoningContent: "hidden reasoning should not be shown",
|
|
}
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "check file"},
|
|
{Role: "assistant", Content: "Previous turn explanation"},
|
|
{Role: "user", Content: "Inspect README.md and update the config example."},
|
|
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
|
}
|
|
|
|
got := toolFeedbackExplanationFromResponse(response, messages)
|
|
want := utils.ToolFeedbackContinuationHint + ": Inspect README.md and update the config example."
|
|
if got != want {
|
|
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want latest user content fallback", got)
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackExplanationForToolCall_DoesNotTruncateLongExplanation(t *testing.T) {
|
|
explanation := "Read README.md first to confirm the current project structure before editing the config example."
|
|
response := &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_1",
|
|
Name: "read_file",
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: explanation,
|
|
},
|
|
}},
|
|
}
|
|
|
|
got := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], nil)
|
|
if got != explanation {
|
|
t.Fatalf("toolFeedbackExplanationForToolCall() = %q, want full explanation", got)
|
|
}
|
|
}
|
|
|
|
func TestToolFeedbackArgsPreview_UsesJSONAndTruncates(t *testing.T) {
|
|
got := toolFeedbackArgsPreview(map[string]any{
|
|
"path": "README.md",
|
|
"limit": 42,
|
|
}, 128)
|
|
want := "{\n \"limit\": 42,\n \"path\": \"README.md\"\n}"
|
|
if got != want {
|
|
t.Fatalf("toolFeedbackArgsPreview() = %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
type picoInterleavedContentProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *picoInterleavedContentProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "intermediate model text",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_tool_limit_test",
|
|
Type: "function",
|
|
Name: "tool_limit_test_tool",
|
|
Arguments: map[string]any{"value": "x"},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "final model text",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *picoInterleavedContentProvider) GetDefaultModel() string {
|
|
return "pico-interleaved-content-model"
|
|
}
|
|
|
|
type picoDistinctToolCallContentProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *picoDistinctToolCallContentProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "intermediate model text",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_tool_limit_test",
|
|
Type: "function",
|
|
Name: "tool_limit_test_tool",
|
|
Arguments: map[string]any{"value": "x"},
|
|
ExtraContent: &providers.ExtraContent{
|
|
ToolFeedbackExplanation: "Read the file before replying.",
|
|
},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "final model text",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *picoDistinctToolCallContentProvider) GetDefaultModel() string {
|
|
return "pico-distinct-tool-call-content-model"
|
|
}
|
|
|
|
type toolLimitOnlyProvider struct{}
|
|
|
|
func (m *toolLimitOnlyProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_tool_limit_test",
|
|
Type: "function",
|
|
Name: "tool_limit_test_tool",
|
|
Arguments: map[string]any{"value": "x"},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
func (m *toolLimitOnlyProvider) GetDefaultModel() string {
|
|
return "tool-limit-only-model"
|
|
}
|
|
|
|
// mockCustomTool is a simple mock tool for registration testing
|
|
type mockCustomTool struct{}
|
|
|
|
func (m *mockCustomTool) Name() string {
|
|
return "mock_custom"
|
|
}
|
|
|
|
func (m *mockCustomTool) Description() string {
|
|
return "Mock custom tool for testing"
|
|
}
|
|
|
|
func (m *mockCustomTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
"additionalProperties": true,
|
|
}
|
|
}
|
|
|
|
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
return tools.SilentResult("Custom tool executed")
|
|
}
|
|
|
|
type handledMediaTool struct {
|
|
store media.MediaStore
|
|
path string
|
|
}
|
|
|
|
func (m *handledMediaTool) Name() string { return "handled_media_tool" }
|
|
func (m *handledMediaTool) Description() string {
|
|
return "Returns a media attachment and fully handles the user response"
|
|
}
|
|
|
|
func (m *handledMediaTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
ref, err := m.store.Store(m.path, media.MediaMeta{
|
|
Filename: filepath.Base(m.path),
|
|
ContentType: "image/png",
|
|
Source: "test:handled_media_tool",
|
|
}, "test:handled_media")
|
|
if err != nil {
|
|
return tools.ErrorResult(err.Error()).WithError(err)
|
|
}
|
|
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
|
}
|
|
|
|
type handledUserTool struct{}
|
|
|
|
func (m *handledUserTool) Name() string { return "handled_user_tool" }
|
|
func (m *handledUserTool) Description() string {
|
|
return "Returns a user-visible result and marks delivery as handled"
|
|
}
|
|
|
|
func (m *handledUserTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (m *handledUserTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
return tools.UserResult("Handled user output from tool.").WithResponseHandled()
|
|
}
|
|
|
|
type handledMediaWithSteeringProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *handledMediaWithSteeringProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Taking the screenshot now.",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_handled_media_steering",
|
|
Type: "function",
|
|
Name: "handled_media_with_steering_tool",
|
|
Arguments: map[string]any{},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
for _, msg := range messages {
|
|
if msg.Role == "user" && msg.Content == "what about this instead?" {
|
|
return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("provider did not receive queued steering message")
|
|
}
|
|
|
|
func (m *handledMediaWithSteeringProvider) GetDefaultModel() string {
|
|
return "handled-media-with-steering-model"
|
|
}
|
|
|
|
type handledMediaWithSteeringTool struct {
|
|
store media.MediaStore
|
|
path string
|
|
loop *AgentLoop
|
|
}
|
|
|
|
func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" }
|
|
func (m *handledMediaWithSteeringTool) Description() string {
|
|
return "Returns handled media and enqueues a steering message during execution"
|
|
}
|
|
|
|
func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil {
|
|
return tools.ErrorResult(err.Error()).WithError(err)
|
|
}
|
|
|
|
ref, err := m.store.Store(m.path, media.MediaMeta{
|
|
Filename: filepath.Base(m.path),
|
|
ContentType: "image/png",
|
|
Source: "test:handled_media_with_steering_tool",
|
|
}, "test:handled_media_with_steering")
|
|
if err != nil {
|
|
return tools.ErrorResult(err.Error()).WithError(err)
|
|
}
|
|
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
|
}
|
|
|
|
type mediaArtifactTool struct {
|
|
store media.MediaStore
|
|
path string
|
|
}
|
|
|
|
func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" }
|
|
func (m *mediaArtifactTool) Description() string {
|
|
return "Returns a media artifact that the agent can forward or save later"
|
|
}
|
|
|
|
func (m *mediaArtifactTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
ref, err := m.store.Store(m.path, media.MediaMeta{
|
|
Filename: filepath.Base(m.path),
|
|
ContentType: "image/png",
|
|
Source: "test:media_artifact_tool",
|
|
}, "test:media_artifact")
|
|
if err != nil {
|
|
return tools.ErrorResult(err.Error()).WithError(err)
|
|
}
|
|
return tools.MediaResult("Artifact created.", []string{ref})
|
|
}
|
|
|
|
type toolLimitTestTool struct{}
|
|
|
|
func (m *toolLimitTestTool) Name() string {
|
|
return "tool_limit_test_tool"
|
|
}
|
|
|
|
func (m *toolLimitTestTool) Description() string {
|
|
return "Tool used to exhaust the iteration budget in tests"
|
|
}
|
|
|
|
func (m *toolLimitTestTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"value": map[string]any{"type": "string"},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (m *toolLimitTestTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
return tools.SilentResult("tool limit test result")
|
|
}
|
|
|
|
// testHelper executes a message and returns the response
|
|
type testHelper struct {
|
|
al *AgentLoop
|
|
}
|
|
|
|
// newChatCompletionTestServer starts an httptest server that emulates a
// minimal chat completions endpoint for tests.
//
// For every request to /chat/completions the handler increments *calls,
// records the request's "model" field in *model, and replies with a single
// choice whose message content is response and whose finish_reason is "stop".
// label is used only to identify the server in failure messages.
//
// The handler runs on a goroutine owned by the HTTP server, where FailNow
// (and therefore t.Fatalf) must not be called. Problems are reported with
// t.Errorf and answered with an HTTP error status, so the test is still
// marked as failed and the client sees a well-formed error response.
//
// The caller is responsible for closing the returned server.
func newChatCompletionTestServer(
	t *testing.T,
	label string,
	response string,
	calls *int,
	model *string,
) *httptest.Server {
	t.Helper()

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			t.Errorf("%s server path = %q, want /chat/completions", label, r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		*calls++
		defer r.Body.Close()

		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode %s request: %v", label, err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		*model = req.Model

		w.Header().Set("Content-Type", "application/json")
		payload := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": response},
					"finish_reason": "stop",
				},
			},
		}
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			// The response may already be partially written; only report it.
			t.Errorf("encode %s response: %v", label, err)
		}
	}))
}
|
|
|
|
// newStrictChatCompletionTestServer starts an httptest server that emulates a
// minimal chat completions endpoint and additionally verifies the requested
// model.
//
// For every request to /chat/completions the handler increments *calls,
// checks that the request's "model" field equals expectedModel, and replies
// with a single choice whose message content is response and whose
// finish_reason is "stop". label is used only to identify the server in
// failure messages.
//
// The handler runs on a goroutine owned by the HTTP server, where FailNow
// (and therefore t.Fatalf) must not be called. Problems are reported with
// t.Errorf and answered with an HTTP error status, so the test is still
// marked as failed and the client sees a well-formed error response.
//
// The caller is responsible for closing the returned server.
func newStrictChatCompletionTestServer(
	t *testing.T,
	label string,
	expectedModel string,
	response string,
	calls *int,
) *httptest.Server {
	t.Helper()

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			t.Errorf("%s server path = %q, want /chat/completions", label, r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		*calls++
		defer r.Body.Close()

		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode %s request: %v", label, err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		if req.Model != expectedModel {
			t.Errorf("%s server model = %q, want %q", label, req.Model, expectedModel)
			http.Error(w, "unexpected model", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		payload := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": response},
					"finish_reason": "stop",
				},
			},
		}
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			// The response may already be partially written; only report it.
			t.Errorf("encode %s response: %v", label, err)
		}
	}))
}
|
|
|
|
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
|
// Use a short timeout to avoid hanging
|
|
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
|
defer cancel()
|
|
|
|
response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg))
|
|
if err != nil {
|
|
tb.Fatalf("processMessage failed: %v", err)
|
|
}
|
|
return response
|
|
}
|
|
|
|
func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage {
|
|
if msg.Context.Channel == "" &&
|
|
msg.Context.Account == "" &&
|
|
msg.Context.ChatID == "" &&
|
|
msg.Context.ChatType == "" &&
|
|
msg.Context.TopicID == "" &&
|
|
msg.Context.SpaceID == "" &&
|
|
msg.Context.SpaceType == "" &&
|
|
msg.Context.SenderID == "" &&
|
|
msg.Context.MessageID == "" &&
|
|
!msg.Context.Mentioned &&
|
|
msg.Context.ReplyToMessageID == "" &&
|
|
msg.Context.ReplyToSenderID == "" &&
|
|
len(msg.Context.ReplyHandles) == 0 &&
|
|
len(msg.Context.Raw) == 0 {
|
|
msg.Context = bus.InboundContext{
|
|
Channel: msg.Channel,
|
|
ChatID: msg.ChatID,
|
|
ChatType: "direct",
|
|
SenderID: msg.SenderID,
|
|
MessageID: msg.MessageID,
|
|
}
|
|
}
|
|
return bus.NormalizeInboundMessage(msg)
|
|
}
|
|
|
|
// responseTimeout is the deadline applied to a single processMessage call made
// through testHelper.executeAndGetResponse.
const responseTimeout time.Duration = 3 * time.Second
|
|
|
|
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &simpleMockProvider{response: "ok"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
msg := bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
},
|
|
Content: "hello",
|
|
}
|
|
|
|
route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context)
|
|
sessionKey := al.allocateRouteSession(route, msg).SessionKey
|
|
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("No default agent found")
|
|
}
|
|
|
|
helper := testHelper{al: al}
|
|
_ = helper.executeAndGetResponse(t, context.Background(), msg)
|
|
|
|
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
|
if len(history) != 2 {
|
|
t.Fatalf("expected session history len=2, got %d", len(history))
|
|
}
|
|
if history[0].Role != "user" || history[0].Content != "hello" {
|
|
t.Fatalf("unexpected first message in session: %+v", history[0])
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"chat"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &countingMockProvider{response: "LLM reply"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
baseMsg := bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "whatsapp",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
},
|
|
}
|
|
|
|
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: baseMsg.Context.Channel,
|
|
ChatID: baseMsg.Context.ChatID,
|
|
ChatType: baseMsg.Context.ChatType,
|
|
SenderID: baseMsg.Context.SenderID,
|
|
},
|
|
Content: "/show channel",
|
|
})
|
|
if showResp != "Current Channel: whatsapp" {
|
|
t.Fatalf("unexpected /show reply: %q", showResp)
|
|
}
|
|
if provider.calls != 0 {
|
|
t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls)
|
|
}
|
|
|
|
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: baseMsg.Context.Channel,
|
|
ChatID: baseMsg.Context.ChatID,
|
|
ChatType: baseMsg.Context.ChatType,
|
|
SenderID: baseMsg.Context.SenderID,
|
|
},
|
|
Content: "/foo",
|
|
})
|
|
if fooResp != "LLM reply" {
|
|
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
|
}
|
|
if provider.calls != 1 {
|
|
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
|
|
}
|
|
|
|
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: baseMsg.Context.Channel,
|
|
ChatID: baseMsg.Context.ChatID,
|
|
ChatType: baseMsg.Context.ChatType,
|
|
SenderID: baseMsg.Context.SenderID,
|
|
},
|
|
Content: "/new",
|
|
})
|
|
if newResp != "LLM reply" {
|
|
t.Fatalf("unexpected /new reply: %q", newResp)
|
|
}
|
|
if provider.calls != 2 {
|
|
t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_MCPCommandsHandledWithoutLLMCall(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
deferred := true
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"chat"},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
MCP: config.MCPConfig{
|
|
ToolConfig: config.ToolConfig{Enabled: true},
|
|
Discovery: config.ToolDiscoveryConfig{Enabled: true},
|
|
Servers: map[string]config.MCPServerConfig{
|
|
"github": {
|
|
Enabled: true,
|
|
Deferred: &deferred,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &countingMockProvider{response: "LLM reply"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
baseContext := bus.InboundContext{
|
|
Channel: "whatsapp",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
}
|
|
|
|
listResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Context: baseContext,
|
|
Content: "/list mcp",
|
|
})
|
|
if !strings.Contains(listResp, "- `github`") || !strings.Contains(listResp, "Deferred: yes") {
|
|
t.Fatalf("unexpected /list mcp reply: %q", listResp)
|
|
}
|
|
if provider.calls != 0 {
|
|
t.Fatalf("LLM should not be called for /list mcp, calls=%d", provider.calls)
|
|
}
|
|
|
|
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Context: baseContext,
|
|
Content: "/show mcp github",
|
|
})
|
|
if showResp != "MCP server 'github' is configured but not connected" {
|
|
t.Fatalf("unexpected /show mcp reply: %q", showResp)
|
|
}
|
|
if provider.calls != 0 {
|
|
t.Fatalf("LLM should not be called for /show mcp, calls=%d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Provider: "openai",
|
|
ModelName: "local",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "local",
|
|
Model: "openai/local-model",
|
|
APIBase: "https://local.example.invalid/v1",
|
|
APIKeys: config.SimpleSecureStrings("test-key"),
|
|
},
|
|
{
|
|
ModelName: "deepseek",
|
|
Model: "openrouter/deepseek/deepseek-v3.2",
|
|
APIBase: "https://openrouter.ai/api/v1",
|
|
APIKeys: config.SimpleSecureStrings("test-key"),
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &countingMockProvider{response: "LLM reply"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "/switch model to deepseek",
|
|
})
|
|
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
|
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
|
}
|
|
|
|
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "/show model",
|
|
})
|
|
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
|
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
|
}
|
|
|
|
if provider.calls != 0 {
|
|
t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Provider: "openai",
|
|
ModelName: "local",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "local",
|
|
Model: "openai/local-model",
|
|
APIBase: "https://local.example.invalid/v1",
|
|
APIKeys: config.SimpleSecureStrings("test-key"),
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &countingMockProvider{response: "LLM reply"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "/switch model to missing",
|
|
})
|
|
if switchResp != `model "missing" not found in model_list or providers` {
|
|
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
|
}
|
|
|
|
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "/show model",
|
|
})
|
|
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
|
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
|
}
|
|
|
|
if provider.calls != 0 {
|
|
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
localCalls := 0
|
|
localModel := ""
|
|
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
|
|
defer localServer.Close()
|
|
|
|
remoteCalls := 0
|
|
remoteModel := ""
|
|
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
|
defer remoteServer.Close()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Provider: "openai",
|
|
ModelName: "local",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "local",
|
|
Model: "openai/Qwen3.5-35B-A3B",
|
|
APIBase: localServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("local-key"),
|
|
},
|
|
{
|
|
ModelName: "deepseek",
|
|
Model: "openrouter/deepseek/deepseek-v3.2",
|
|
APIBase: remoteServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("remote-key"),
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider, _, err := providers.CreateProvider(cfg)
|
|
if err != nil {
|
|
t.Fatalf("CreateProvider() error = %v", err)
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hello before switch",
|
|
})
|
|
if firstResp != "local reply" {
|
|
t.Fatalf("unexpected response before switch: %q", firstResp)
|
|
}
|
|
if localCalls != 1 {
|
|
t.Fatalf("local calls before switch = %d, want 1", localCalls)
|
|
}
|
|
if remoteCalls != 0 {
|
|
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
|
|
}
|
|
if localModel != "Qwen3.5-35B-A3B" {
|
|
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
|
|
}
|
|
|
|
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "/switch model to deepseek",
|
|
})
|
|
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
|
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
|
}
|
|
|
|
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hello after switch",
|
|
})
|
|
if secondResp != "remote reply" {
|
|
t.Fatalf("unexpected response after switch: %q", secondResp)
|
|
}
|
|
if localCalls != 1 {
|
|
t.Fatalf("local calls after switch = %d, want 1", localCalls)
|
|
}
|
|
if remoteCalls != 1 {
|
|
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
|
|
}
|
|
if remoteModel != "deepseek-v3.2" {
|
|
t.Fatalf(
|
|
"remote model after switch = %q, want %q",
|
|
remoteModel,
|
|
"deepseek-v3.2",
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
heavyCalls := 0
|
|
heavyServer := newStrictChatCompletionTestServer(
|
|
t,
|
|
"heavy",
|
|
"gemini-2.5-flash",
|
|
"heavy reply",
|
|
&heavyCalls,
|
|
)
|
|
defer heavyServer.Close()
|
|
|
|
lightCalls := 0
|
|
lightServer := newStrictChatCompletionTestServer(
|
|
t,
|
|
"light",
|
|
"qwen2.5:0.5b",
|
|
"light reply",
|
|
&lightCalls,
|
|
)
|
|
defer lightServer.Close()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "gemini-main",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
Routing: &config.RoutingConfig{
|
|
Enabled: true,
|
|
LightModel: "qwen-light",
|
|
Threshold: 0.99,
|
|
},
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "gemini-main",
|
|
Model: "gemini/gemini-2.5-flash",
|
|
APIBase: heavyServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("heavy-key"),
|
|
},
|
|
{
|
|
ModelName: "qwen-light",
|
|
Model: "ollama/qwen2.5:0.5b",
|
|
APIBase: lightServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("light-key"),
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider, _, err := providers.CreateProvider(cfg)
|
|
if err != nil {
|
|
t.Fatalf("CreateProvider() error = %v", err)
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hi",
|
|
})
|
|
if resp != "light reply" {
|
|
t.Fatalf("response = %q, want %q", resp, "light reply")
|
|
}
|
|
if heavyCalls != 0 {
|
|
t.Fatalf("heavy calls = %d, want 0", heavyCalls)
|
|
}
|
|
if lightCalls != 1 {
|
|
t.Fatalf("light calls = %d, want 1", lightCalls)
|
|
}
|
|
}
|
|
|
|
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
|
|
// bug #2140. It verifies that when the primary model returns a rate-limit error
|
|
// the fallback closure routes the retry to the fallback model's own provider
|
|
// (its own api_base), not back to the primary provider's endpoint.
|
|
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
|
workspace := t.TempDir()
|
|
|
|
primaryCalls := 0
|
|
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
primaryCalls++
|
|
// Return 429 so FallbackChain classifies this as retriable and moves on.
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"error": map[string]any{
|
|
"message": "rate limit exceeded",
|
|
"type": "rate_limit_error",
|
|
},
|
|
})
|
|
}))
|
|
defer primaryServer.Close()
|
|
|
|
fallbackCalls := 0
|
|
fallbackServer := newStrictChatCompletionTestServer(
|
|
t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
|
|
)
|
|
defer fallbackServer.Close()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: workspace,
|
|
ModelName: "mistral-primary",
|
|
ModelFallbacks: []string{"gemma-fallback"},
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 3,
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "mistral-primary",
|
|
Model: "openrouter/mistralai/mistral-small-3.1",
|
|
APIBase: primaryServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("primary-key"),
|
|
Workspace: workspace,
|
|
},
|
|
{
|
|
ModelName: "gemma-fallback",
|
|
Model: "openrouter/gemma-3-27b-it",
|
|
APIBase: fallbackServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
|
Workspace: workspace,
|
|
},
|
|
},
|
|
}
|
|
|
|
provider, _, err := providers.CreateProvider(cfg)
|
|
if err != nil {
|
|
t.Fatalf("CreateProvider() error = %v", err)
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hi",
|
|
})
|
|
|
|
if resp != "fallback reply" {
|
|
t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
|
|
}
|
|
if primaryCalls == 0 {
|
|
t.Fatal("primary server was never called; expected at least one attempt")
|
|
}
|
|
if fallbackCalls != 1 {
|
|
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
|
|
}
|
|
}
|
|
|
|
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
|
|
// that when a candidate has no model_list entry it is absent from CandidateProviders
|
|
// and the fallback closure falls back to activeProvider instead of panicking.
|
|
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
|
|
workspace := t.TempDir()
|
|
|
|
// Primary server: returns 429 on first call, succeeds on second.
|
|
// Both the primary and the unregistered fallback share this server
|
|
// (same api_base) so activeProvider routes both calls here.
|
|
callCount := 0
|
|
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
callCount++
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if callCount == 1 {
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
|
|
})
|
|
return
|
|
}
|
|
// Second call (fallback via activeProvider) succeeds.
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
|
|
},
|
|
})
|
|
}))
|
|
defer primaryServer.Close()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: workspace,
|
|
ModelName: "primary-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 3,
|
|
// No model_list entry for this alias — absent from CandidateProviders.
|
|
ModelFallbacks: []string{"openrouter/fallback-model"},
|
|
},
|
|
},
|
|
ModelList: []*config.ModelConfig{
|
|
{
|
|
ModelName: "primary-model",
|
|
Model: "openrouter/primary-model",
|
|
APIBase: primaryServer.URL,
|
|
APIKeys: config.SimpleSecureStrings("primary-key"),
|
|
Workspace: workspace,
|
|
},
|
|
},
|
|
}
|
|
|
|
provider, _, err := providers.CreateProvider(cfg)
|
|
if err != nil {
|
|
t.Fatalf("CreateProvider() error = %v", err)
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
helper := testHelper{al: al}
|
|
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hi",
|
|
})
|
|
|
|
if resp != "active provider reply" {
|
|
t.Fatalf("response = %q, want %q", resp, "active provider reply")
|
|
}
|
|
if callCount < 2 {
|
|
t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
|
|
}
|
|
}
|
|
|
|
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
|
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &simpleMockProvider{response: "File operation complete"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
// ReadFileTool returns SilentResult, which should not send user message
|
|
ctx := context.Background()
|
|
msg := bus.InboundMessage{
|
|
Channel: "test",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "read test.txt",
|
|
SessionKey: "test-session",
|
|
}
|
|
|
|
response := helper.executeAndGetResponse(t, ctx, msg)
|
|
|
|
// Silent tool should return the LLM's response directly
|
|
if response != "File operation complete" {
|
|
t.Errorf("Expected 'File operation complete', got: %s", response)
|
|
}
|
|
}
|
|
|
|
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
|
|
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &simpleMockProvider{response: "Command output: hello world"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
helper := testHelper{al: al}
|
|
|
|
// ExecTool returns UserResult, which should send user message
|
|
ctx := context.Background()
|
|
msg := bus.InboundMessage{
|
|
Channel: "test",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "run hello",
|
|
SessionKey: "test-session",
|
|
}
|
|
|
|
response := helper.executeAndGetResponse(t, ctx, msg)
|
|
|
|
// User-facing tool should include the output in final response
|
|
if response != "Command output: hello world" {
|
|
t.Errorf("Expected 'Command output: hello world', got: %s", response)
|
|
}
|
|
}
|
|
|
|
// failFirstMockProvider fails on the first N calls with a specific error
|
|
type failFirstMockProvider struct {
|
|
failures int
|
|
currentCall int
|
|
failError error
|
|
successResp string
|
|
}
|
|
|
|
func (m *failFirstMockProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.currentCall++
|
|
if m.currentCall <= m.failures {
|
|
return nil, m.failError
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: m.successResp,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *failFirstMockProvider) GetDefaultModel() string {
|
|
return "mock-fail-model"
|
|
}
|
|
|
|
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
|
|
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
|
|
// Create a provider that fails once with a context error
|
|
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
|
provider := &failFirstMockProvider{
|
|
failures: 1,
|
|
failError: contextErr,
|
|
successResp: "Recovered from context error",
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Inject some history to simulate a full context.
|
|
// Session history only stores user/assistant/tool messages — the system
|
|
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
|
sessionKey := "test-session-context"
|
|
history := []providers.Message{
|
|
{Role: "user", Content: "Old message 1"},
|
|
{Role: "assistant", Content: "Old response 1"},
|
|
{Role: "user", Content: "Old message 2"},
|
|
{Role: "assistant", Content: "Old response 2"},
|
|
{Role: "user", Content: "Trigger message"},
|
|
}
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("No default agent found")
|
|
}
|
|
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
|
|
|
// Call ProcessDirectWithChannel
|
|
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
|
response, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"Trigger message",
|
|
sessionKey,
|
|
"test",
|
|
"test-chat",
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Expected success after retry, got error: %v", err)
|
|
}
|
|
|
|
if response != "Recovered from context error" {
|
|
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
|
|
}
|
|
|
|
// We expect 2 calls: 1st failed, 2nd succeeded
|
|
if provider.currentCall != 2 {
|
|
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
|
|
}
|
|
|
|
// Check final history length
|
|
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
|
// We verify that the history has been modified (compressed)
|
|
// Original length: 5
|
|
// Expected behavior: compression drops ~50% of Turns
|
|
// Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
|
|
if len(finalHistory) >= 7 {
|
|
t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
|
|
}
|
|
}
|
|
|
|
type visionUnsupportedMediaProvider struct {
|
|
calls int
|
|
mediaSeen []bool
|
|
}
|
|
|
|
func (p *visionUnsupportedMediaProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.calls++
|
|
|
|
hasMedia := false
|
|
for _, msg := range messages {
|
|
for _, ref := range msg.Media {
|
|
if strings.TrimSpace(ref) != "" {
|
|
hasMedia = true
|
|
break
|
|
}
|
|
}
|
|
if hasMedia {
|
|
break
|
|
}
|
|
}
|
|
p.mediaSeen = append(p.mediaSeen, hasMedia)
|
|
|
|
if hasMedia {
|
|
return nil, fmt.Errorf("API request failed: " +
|
|
"Status: 404 Body: {\"error\":{\"message\":\"No endpoints found that support image input\"}}")
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "ok",
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (p *visionUnsupportedMediaProvider) GetDefaultModel() string {
|
|
return "mock-fail-model"
|
|
}
|
|
|
|
func TestAgentLoop_VisionUnsupportedErrorStripsSessionMedia(t *testing.T) {
|
|
workspace := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: workspace,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 3,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &visionUnsupportedMediaProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
sessionKey := "agent:main:telegram:direct:user1"
|
|
|
|
timeoutCtx, cancel := context.WithTimeout(context.Background(), responseTimeout)
|
|
defer cancel()
|
|
|
|
resp, err := al.processMessage(timeoutCtx, testInboundMessage(bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
MessageID: "m1",
|
|
},
|
|
Content: "describe this",
|
|
Media: []string{"data:image/png;base64,abc123"},
|
|
SessionKey: sessionKey,
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if resp != "ok" {
|
|
t.Fatalf("response = %q, want %q", resp, "ok")
|
|
}
|
|
if provider.calls != 2 {
|
|
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
|
|
}
|
|
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
|
|
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
|
|
}
|
|
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
history := agent.Sessions.GetHistory(sessionKey)
|
|
for i, msg := range history {
|
|
if len(msg.Media) > 0 {
|
|
t.Fatalf("history[%d].Media = %v, want no media after stripping", i, msg.Media)
|
|
}
|
|
}
|
|
|
|
timeoutCtx2, cancel2 := context.WithTimeout(context.Background(), responseTimeout)
|
|
defer cancel2()
|
|
|
|
resp2, err := al.processMessage(timeoutCtx2, testInboundMessage(bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
MessageID: "m2",
|
|
},
|
|
Content: "hello again",
|
|
SessionKey: sessionKey,
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() second call error = %v", err)
|
|
}
|
|
if resp2 != "ok" {
|
|
t.Fatalf("second response = %q, want %q", resp2, "ok")
|
|
}
|
|
if provider.calls != 3 {
|
|
t.Fatalf("calls after second turn = %d, want %d", provider.calls, 3)
|
|
}
|
|
if !slices.Equal(provider.mediaSeen, []bool{true, false, false}) {
|
|
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false, false})
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 3,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &simpleMockProvider{response: ""}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
|
|
if err != nil {
|
|
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
|
}
|
|
if response != defaultResponse {
|
|
t.Fatalf("response = %q, want %q", response, defaultResponse)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 1,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolLimitOnlyProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
al.RegisterTool(&toolLimitTestTool{})
|
|
|
|
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
|
|
if err != nil {
|
|
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
|
}
|
|
if response != toolLimitResponse {
|
|
t.Fatalf("response = %q, want %q", response, toolLimitResponse)
|
|
}
|
|
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("No default agent found")
|
|
}
|
|
route := al.registry.ResolveRoute(bus.InboundContext{
|
|
Channel: "test",
|
|
ChatType: "direct",
|
|
SenderID: "cron",
|
|
})
|
|
history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
|
Channel: "test",
|
|
SenderID: "cron",
|
|
ChatID: "chat1",
|
|
})).SessionKey)
|
|
if len(history) != 4 {
|
|
t.Fatalf("history len = %d, want 4", len(history))
|
|
}
|
|
assertRoles(t, history, "user", "assistant", "tool", "assistant")
|
|
if history[3].Content != toolLimitResponse {
|
|
t.Fatalf("final assistant content = %q, want %q", history[3].Content, toolLimitResponse)
|
|
}
|
|
}
|
|
|
|
// TestProcessDirectWithChannel_TriggersMCPInitialization verifies that
|
|
// ProcessDirectWithChannel triggers MCP initialization when MCP is enabled.
|
|
// Note: Manager is only initialized when at least one MCP server is configured
|
|
// and successfully connected.
|
|
func TestProcessDirectWithChannel_TriggersMCPInitialization(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
// Test with MCP enabled but no servers - should not initialize manager
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
MCP: config.MCPConfig{
|
|
ToolConfig: config.ToolConfig{
|
|
Enabled: true,
|
|
},
|
|
// No servers configured - manager should not be initialized
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
defer al.Close()
|
|
|
|
if al.mcp.hasManager() {
|
|
t.Fatal("expected MCP manager to be nil before first direct processing")
|
|
}
|
|
|
|
_, err = al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"hello",
|
|
"session-1",
|
|
"cli",
|
|
"direct",
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
|
}
|
|
|
|
// Manager should not be initialized when no servers are configured
|
|
if al.mcp.hasManager() {
|
|
t.Fatal("expected MCP manager to be nil when no servers are configured")
|
|
}
|
|
}
|
|
|
|
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
|
chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel manager: %v", err)
|
|
}
|
|
for name, id := range map[string]string{
|
|
"whatsapp": "rid-whatsapp",
|
|
"telegram": "rid-telegram",
|
|
"feishu": "rid-feishu",
|
|
"discord": "rid-discord",
|
|
"maixcam": "rid-maixcam",
|
|
"qq": "rid-qq",
|
|
"dingtalk": "rid-dingtalk",
|
|
"slack": "rid-slack",
|
|
"line": "rid-line",
|
|
"onebot": "rid-onebot",
|
|
"wecom": "rid-wecom",
|
|
} {
|
|
chManager.RegisterChannel(name, &fakeChannel{id: id})
|
|
}
|
|
al.SetChannelManager(chManager)
|
|
tests := []struct {
|
|
channel string
|
|
wantID string
|
|
}{
|
|
{channel: "whatsapp", wantID: "rid-whatsapp"},
|
|
{channel: "telegram", wantID: "rid-telegram"},
|
|
{channel: "feishu", wantID: "rid-feishu"},
|
|
{channel: "discord", wantID: "rid-discord"},
|
|
{channel: "maixcam", wantID: "rid-maixcam"},
|
|
{channel: "qq", wantID: "rid-qq"},
|
|
{channel: "dingtalk", wantID: "rid-dingtalk"},
|
|
{channel: "slack", wantID: "rid-slack"},
|
|
{channel: "line", wantID: "rid-line"},
|
|
{channel: "onebot", wantID: "rid-onebot"},
|
|
{channel: "wecom", wantID: "rid-wecom"},
|
|
{channel: "unknown", wantID: ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.channel, func(t *testing.T) {
|
|
got := al.targetReasoningChannelID(tt.channel)
|
|
if got != tt.wantID {
|
|
t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleReasoning(t *testing.T) {
|
|
newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) {
|
|
t.Helper()
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus
|
|
}
|
|
|
|
t.Run("skips when any required field is empty", func(t *testing.T) {
|
|
al, msgBus := newLoop(t)
|
|
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
for {
|
|
select {
|
|
case msg, ok := <-msgBus.OutboundChan():
|
|
if !ok {
|
|
t.Fatalf("expected no outbound message, got %+v", msg)
|
|
}
|
|
if msg.Content == "reasoning" {
|
|
t.Fatalf("expected no message for empty chatID, got %+v", msg)
|
|
}
|
|
return
|
|
case <-ctx.Done():
|
|
t.Log("expected an outbound message, got none within timeout")
|
|
return
|
|
default:
|
|
// Continue to check for message
|
|
time.Sleep(5 * time.Millisecond) // Avoid busy loop
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("publishes one message for non telegram", func(t *testing.T) {
|
|
al, msgBus := newLoop(t)
|
|
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
|
|
|
|
msg, ok := <-msgBus.OutboundChan()
|
|
if !ok {
|
|
t.Fatal("expected an outbound message")
|
|
}
|
|
if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" {
|
|
t.Fatalf("unexpected outbound message: %+v", msg)
|
|
}
|
|
})
|
|
|
|
t.Run("publishes one message for telegram", func(t *testing.T) {
|
|
al, msgBus := newLoop(t)
|
|
reasoning := "hello telegram reasoning"
|
|
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Fatal("expected an outbound message, got none within timeout")
|
|
return
|
|
case msg, ok := <-msgBus.OutboundChan():
|
|
if !ok {
|
|
t.Fatal("expected outbound message")
|
|
}
|
|
|
|
if msg.Channel != "telegram" {
|
|
t.Fatalf("expected telegram channel message, got %+v", msg)
|
|
}
|
|
if msg.ChatID != "tg-chat" {
|
|
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
|
}
|
|
if msg.Content != reasoning {
|
|
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
})
|
|
t.Run("expired ctx", func(t *testing.T) {
|
|
al, msgBus := newLoop(t)
|
|
reasoning := "hello telegram reasoning"
|
|
|
|
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
|
|
|
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer consumeCancel()
|
|
|
|
for {
|
|
select {
|
|
case msg, ok := <-msgBus.OutboundChan():
|
|
if !ok {
|
|
t.Fatalf("expected no outbound message, but received: %+v", msg)
|
|
}
|
|
t.Logf("Received unexpected outbound message: %+v", msg)
|
|
return
|
|
case <-consumeCtx.Done():
|
|
t.Fatalf("failed: no message received within timeout")
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("returns promptly when bus is full", func(t *testing.T) {
|
|
al, msgBus := newLoop(t)
|
|
|
|
// Fill the outbound bus buffer until a publish would block.
|
|
// Use a short timeout to detect when the buffer is full,
|
|
// rather than hardcoding the buffer size.
|
|
for i := 0; ; i++ {
|
|
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
|
|
Context: bus.NewOutboundContext("filler", "filler", ""),
|
|
Content: fmt.Sprintf("filler-%d", i),
|
|
})
|
|
fillCancel()
|
|
if err != nil {
|
|
// Buffer is full (timed out trying to send).
|
|
break
|
|
}
|
|
}
|
|
|
|
// Use a short-deadline parent context to bound the test.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
al.handleReasoning(ctx, "should timeout", "slack", "channel-full")
|
|
elapsed := time.Since(start)
|
|
|
|
// handleReasoning uses a 5s internal timeout, but the parent ctx
|
|
// expires in 500ms. It should return within ~500ms, not 5s.
|
|
if elapsed > 2*time.Second {
|
|
t.Fatalf("handleReasoning blocked too long (%v); expected prompt return", elapsed)
|
|
}
|
|
|
|
// Drain the bus and verify the reasoning message was NOT published
|
|
// (it should have been dropped due to timeout).
|
|
timeer := time.After(1 * time.Second)
|
|
for {
|
|
select {
|
|
case <-timeer:
|
|
t.Logf(
|
|
"no reasoning message received after draining bus for 1s, as expected,length=%d",
|
|
len(msgBus.OutboundChan()),
|
|
)
|
|
return
|
|
case msg, ok := <-msgBus.OutboundChan():
|
|
if !ok {
|
|
break
|
|
}
|
|
if msg.Content == "should timeout" {
|
|
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &reasoningContentProvider{
|
|
response: "final answer",
|
|
reasoningContent: "thinking trace",
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
chManager, err := channels.NewManager(&config.Config{}, msgBus, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel manager: %v", err)
|
|
}
|
|
chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"})
|
|
al.SetChannelManager(chManager)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hello",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "final answer" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "final answer")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
if outbound.Channel != "telegram" {
|
|
t.Fatalf("reasoning channel = %q, want %q", outbound.Channel, "telegram")
|
|
}
|
|
if outbound.ChatID != "reason-chat" {
|
|
t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat")
|
|
}
|
|
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" {
|
|
t.Fatalf("unexpected reasoning context: %+v", outbound.Context)
|
|
}
|
|
if outbound.Content != "thinking trace" {
|
|
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("expected reasoning content to be published to reasoning channel")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &reasoningContentProvider{
|
|
response: "final answer",
|
|
reasoningContent: "thinking trace",
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "pico",
|
|
SenderID: "user1",
|
|
ChatID: "pico:test-session",
|
|
Content: "hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "final answer" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "final answer")
|
|
}
|
|
|
|
var thoughtMsg *bus.OutboundMessage
|
|
deadline := time.After(3 * time.Second)
|
|
|
|
for thoughtMsg == nil {
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
msg := outbound
|
|
if msg.Content == "thinking trace" {
|
|
thoughtMsg = &msg
|
|
}
|
|
case <-deadline:
|
|
t.Fatal("expected thought outbound message for pico")
|
|
}
|
|
}
|
|
|
|
if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" {
|
|
t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID)
|
|
}
|
|
if thoughtMsg.Context.Raw[metadataKeyMessageKind] != messageKindThought {
|
|
t.Fatalf(
|
|
"thought metadata kind = %q, want %q",
|
|
thoughtMsg.Context.Raw[metadataKeyMessageKind],
|
|
messageKindThought,
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
heartbeatFile := filepath.Join(tmpDir, "heartbeat-task.txt")
|
|
if err := os.WriteFile(heartbeatFile, []byte("heartbeat task"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
ToolFeedback: config.ToolFeedbackConfig{
|
|
Enabled: true,
|
|
MaxArgsLength: 300,
|
|
},
|
|
},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
ReadFile: config.ReadFileToolConfig{
|
|
Enabled: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.ProcessHeartbeat(context.Background(), "check heartbeat tasks", "telegram", "chat-1")
|
|
if err != nil {
|
|
t.Fatalf("ProcessHeartbeat() error = %v", err)
|
|
}
|
|
if response != "HEARTBEAT_OK" {
|
|
t.Fatalf("ProcessHeartbeat() response = %q, want %q", response, "HEARTBEAT_OK")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
t.Fatalf("expected no outbound tool feedback during heartbeat, got %+v", outbound)
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
heartbeatFile := filepath.Join(tmpDir, "tool-feedback.txt")
|
|
if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
ToolFeedback: config.ToolFeedbackConfig{
|
|
Enabled: true,
|
|
MaxArgsLength: 300,
|
|
},
|
|
},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
ReadFile: config.ReadFileToolConfig{
|
|
Enabled: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user-1",
|
|
ChatID: "chat-1",
|
|
Content: "check tool feedback",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "HEARTBEAT_OK" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "HEARTBEAT_OK")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
|
|
if outbound.Channel != "telegram" {
|
|
t.Fatalf("tool feedback channel = %q, want %q", outbound.Channel, "telegram")
|
|
}
|
|
if outbound.ChatID != "chat-1" {
|
|
t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1")
|
|
}
|
|
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
|
t.Fatalf("unexpected tool feedback context: %+v", outbound.Context)
|
|
}
|
|
if !strings.Contains(outbound.Content, "`read_file`") {
|
|
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) {
|
|
t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, "check tool feedback") {
|
|
t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, "\"path\":") {
|
|
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
|
|
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
|
|
}
|
|
if strings.Contains(outbound.Content, "Previous turn explanation") {
|
|
t.Fatalf("tool feedback content = %q, want no previous assistant fallback", outbound.Content)
|
|
}
|
|
if outbound.AgentID != "main" {
|
|
t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID)
|
|
}
|
|
if outbound.SessionKey == "" {
|
|
t.Fatal("expected tool feedback to carry session_key")
|
|
}
|
|
if outbound.Scope == nil || outbound.Scope.AgentID != "main" || outbound.Scope.Channel != "telegram" {
|
|
t.Fatalf("expected tool feedback scope, got %+v", outbound.Scope)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected outbound tool feedback for regular messages")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_PersistsReasoningContentInSessionHistory(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &reasoningContentProvider{
|
|
response: "final answer",
|
|
reasoningContent: "thinking trace",
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "pico",
|
|
SenderID: "user1",
|
|
ChatID: "pico:test-session",
|
|
Content: "hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "final answer" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "final answer")
|
|
}
|
|
|
|
store := al.GetRegistry().GetDefaultAgent().Sessions
|
|
sessionKeys := store.ListSessions()
|
|
if len(sessionKeys) != 1 {
|
|
t.Fatalf("session keys = %v, want exactly 1 active session", sessionKeys)
|
|
}
|
|
history := store.GetHistory(sessionKeys[0])
|
|
if len(history) < 2 {
|
|
t.Fatalf("session history len = %d, want at least 2", len(history))
|
|
}
|
|
|
|
last := history[len(history)-1]
|
|
if last.Role != "assistant" {
|
|
t.Fatalf("last message role = %q, want assistant", last.Role)
|
|
}
|
|
if last.Content != "final answer" {
|
|
t.Fatalf("last message content = %q, want %q", last.Content, "final answer")
|
|
}
|
|
if last.ReasoningContent != "thinking trace" {
|
|
t.Fatalf("last message reasoning_content = %q, want %q", last.ReasoningContent, "thinking trace")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_PersistsReasoningToolResponseAsSingleAssistantRecord(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
inspectPath := filepath.Join(tmpDir, "inspect.txt")
|
|
if err := os.WriteFile(inspectPath, []byte("inspect me"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile(inspectPath) error = %v", err)
|
|
}
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = tmpDir
|
|
cfg.Agents.Defaults.ModelName = "test-model"
|
|
cfg.Agents.Defaults.MaxTokens = 4096
|
|
cfg.Agents.Defaults.MaxToolIterations = 10
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &reasoningVisibleToolProvider{filePath: inspectPath}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "DONE" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "DONE")
|
|
}
|
|
|
|
store := al.GetRegistry().GetDefaultAgent().Sessions
|
|
sessionKeys := store.ListSessions()
|
|
if len(sessionKeys) != 1 {
|
|
t.Fatalf("session keys = %v, want exactly 1 active session", sessionKeys)
|
|
}
|
|
|
|
history := store.GetHistory(sessionKeys[0])
|
|
if len(history) < 3 {
|
|
t.Fatalf("session history len = %d, want at least 3", len(history))
|
|
}
|
|
|
|
var assistantWithToolCall *providers.Message
|
|
for i := range history {
|
|
msg := history[i]
|
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
|
assistantWithToolCall = &msg
|
|
break
|
|
}
|
|
}
|
|
if assistantWithToolCall == nil {
|
|
t.Fatal("expected assistant history record with tool_calls")
|
|
}
|
|
if assistantWithToolCall.Content != "I'll inspect that file now." {
|
|
t.Fatalf("assistant content = %q, want %q", assistantWithToolCall.Content, "I'll inspect that file now.")
|
|
}
|
|
if assistantWithToolCall.ReasoningContent != "Read the file before answering." {
|
|
t.Fatalf("assistant reasoning_content = %q, want preserved", assistantWithToolCall.ReasoningContent)
|
|
}
|
|
if len(assistantWithToolCall.ToolCalls) != 1 {
|
|
t.Fatalf("assistant tool calls = %+v, want single read_file tool", assistantWithToolCall.ToolCalls)
|
|
}
|
|
if got := providers.NormalizeToolCall(assistantWithToolCall.ToolCalls[0]).Name; got != "read_file" {
|
|
t.Fatalf("assistant tool calls = %+v, want single read_file tool", assistantWithToolCall.ToolCalls)
|
|
}
|
|
|
|
sessionDir := filepath.Join(tmpDir, "sessions")
|
|
entries, err := os.ReadDir(sessionDir)
|
|
if err != nil {
|
|
t.Fatalf("ReadDir(%q) error = %v", sessionDir, err)
|
|
}
|
|
|
|
var jsonlPath string
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
|
|
continue
|
|
}
|
|
jsonlPath = filepath.Join(sessionDir, entry.Name())
|
|
break
|
|
}
|
|
if jsonlPath == "" {
|
|
t.Fatal("expected session jsonl file to be created")
|
|
}
|
|
|
|
data, err := os.ReadFile(jsonlPath)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile(%q) error = %v", jsonlPath, err)
|
|
}
|
|
|
|
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
|
if len(lines) < 3 {
|
|
t.Fatalf("jsonl lines = %d, want at least 3", len(lines))
|
|
}
|
|
|
|
matchingRecords := 0
|
|
for _, line := range lines {
|
|
var msg providers.Message
|
|
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
|
t.Fatalf("Unmarshal(jsonl line) error = %v", err)
|
|
}
|
|
if msg.Role != "assistant" {
|
|
continue
|
|
}
|
|
if msg.Content == "I'll inspect that file now." || msg.ReasoningContent == "Read the file before answering." {
|
|
matchingRecords++
|
|
toolName := ""
|
|
if len(msg.ToolCalls) == 1 {
|
|
toolName = providers.NormalizeToolCall(msg.ToolCalls[0]).Name
|
|
}
|
|
if msg.Content != "I'll inspect that file now." ||
|
|
msg.ReasoningContent != "Read the file before answering." ||
|
|
len(msg.ToolCalls) != 1 ||
|
|
toolName != "read_file" {
|
|
t.Fatalf("assistant jsonl record = %+v, want content+reasoning+tool_calls in one line", msg)
|
|
}
|
|
}
|
|
}
|
|
if matchingRecords != 1 {
|
|
t.Fatalf("matching assistant jsonl records = %d, want exactly 1 canonical assistant record", matchingRecords)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
heartbeatFile := filepath.Join(tmpDir, "tool-feedback-reasoning.txt")
|
|
if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
ToolFeedback: config.ToolFeedbackConfig{
|
|
Enabled: true,
|
|
MaxArgsLength: 300,
|
|
},
|
|
},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
ReadFile: config.ReadFileToolConfig{
|
|
Enabled: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolFeedbackReasoningProvider{filePath: heartbeatFile}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user-1",
|
|
ChatID: "chat-1",
|
|
Content: "check reasoning fallback",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "DONE" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "DONE")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
|
|
if !strings.Contains(outbound.Content, "`read_file`") {
|
|
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) {
|
|
t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, "check reasoning fallback") {
|
|
t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, "\"path\":") {
|
|
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
|
|
}
|
|
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
|
|
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
|
|
}
|
|
if strings.Contains(outbound.Content, "Read README.md first") {
|
|
t.Fatalf("tool feedback content = %q, should not leak hidden reasoning", outbound.Content)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected outbound tool feedback without leaking reasoning")
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_DoesNotPublishToolFeedbackForDiscordWhenDisabled(t *testing.T) {
|
|
assertToolFeedbackNotPublishedWhenDisabled(t, "discord")
|
|
}
|
|
|
|
func assertToolFeedbackNotPublishedWhenDisabled(t *testing.T, channel string) {
|
|
t.Helper()
|
|
|
|
tmpDir := t.TempDir()
|
|
heartbeatFile := filepath.Join(tmpDir, "tool-feedback-"+channel+".txt")
|
|
if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
Tools: config.ToolsConfig{
|
|
ReadFile: config.ReadFileToolConfig{
|
|
Enabled: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: channel,
|
|
SenderID: "user-1",
|
|
ChatID: "chat-1",
|
|
Content: "check tool feedback",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "HEARTBEAT_OK" {
|
|
t.Fatalf("processMessage() response = %q, want %q", response, "HEARTBEAT_OK")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
t.Fatalf("expected no outbound tool feedback for %s when disabled, got %+v", channel, outbound)
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_DoesNotPublishToolFeedbackForTelegramWhenDisabled(t *testing.T) {
|
|
assertToolFeedbackNotPublishedWhenDisabled(t, "telegram")
|
|
}
|
|
|
|
func TestProcessMessage_DoesNotPublishToolFeedbackForFeishuWhenDisabled(t *testing.T) {
|
|
assertToolFeedbackNotPublishedWhenDisabled(t, "feishu")
|
|
}
|
|
|
|
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.Workspace = t.TempDir()
|
|
cfg.Agents.Defaults.ModelName = "test-model"
|
|
cfg.Agents.Defaults.MaxTokens = 4096
|
|
cfg.Agents.Defaults.MaxToolIterations = 10
|
|
cfg.Session.Dimensions = []string{"chat"}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &messageToolProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "telegram",
|
|
SenderID: "user-1",
|
|
ChatID: "chat-1",
|
|
Content: "send a direct message",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response == "" {
|
|
t.Fatal("expected processMessage() to return a final loop response")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
if outbound.Content != "direct tool message" {
|
|
t.Fatalf("outbound content = %q, want direct tool message", outbound.Content)
|
|
}
|
|
if outbound.AgentID != "main" {
|
|
t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID)
|
|
}
|
|
if outbound.SessionKey == "" {
|
|
t.Fatal("expected message tool outbound to carry session_key")
|
|
}
|
|
if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" {
|
|
t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope)
|
|
}
|
|
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
|
t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected message tool outbound")
|
|
}
|
|
}
|
|
|
|
func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &picoDistinctToolCallContentProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
agent.Tools.Register(&toolLimitTestTool{})
|
|
|
|
runCtx, runCancel := context.WithCancel(context.Background())
|
|
defer runCancel()
|
|
|
|
runDone := make(chan error, 1)
|
|
go func() {
|
|
runDone <- al.Run(runCtx)
|
|
}()
|
|
|
|
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
|
Channel: "pico",
|
|
SenderID: "user-1",
|
|
ChatID: "session-1",
|
|
Content: "run with tools",
|
|
}); err != nil {
|
|
t.Fatalf("PublishInbound() error = %v", err)
|
|
}
|
|
|
|
outputs := make([]bus.OutboundMessage, 0, 3)
|
|
deadline := time.After(2 * time.Second)
|
|
for len(outputs) < 3 {
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
outputs = append(outputs, outbound)
|
|
case <-deadline:
|
|
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
|
|
}
|
|
}
|
|
|
|
if outputs[0].Content != "intermediate model text" {
|
|
t.Fatalf("first outbound content = %q, want %q", outputs[0].Content, "intermediate model text")
|
|
}
|
|
if outputs[1].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
|
|
t.Fatalf("second outbound = %+v, want tool_calls message", outputs[1])
|
|
}
|
|
if !strings.Contains(outputs[1].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
|
|
t.Fatalf("second outbound tool_calls = %q, want tool name", outputs[1].Context.Raw[metadataKeyToolCalls])
|
|
}
|
|
if outputs[2].Content != "final model text" {
|
|
t.Fatalf("third outbound content = %q, want %q", outputs[2].Content, "final model text")
|
|
}
|
|
|
|
runCancel()
|
|
select {
|
|
case err := <-runDone:
|
|
if err != nil {
|
|
t.Fatalf("Run() error = %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for Run() to exit")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
if outbound.Content == "final model text" {
|
|
t.Fatalf("unexpected duplicate final pico output: %+v", outbound)
|
|
}
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestRunAgentLoop_PicoSkipsInterimPublishWhenNotAllowed(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &picoInterleavedContentProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
agent.Tools.Register(&toolLimitTestTool{})
|
|
|
|
response, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "agent:main:pico:session-1",
|
|
Channel: "pico",
|
|
ChatID: "session-1",
|
|
UserMessage: "run with tools",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
AllowInterimPicoPublish: false,
|
|
SuppressToolFeedback: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop() error = %v", err)
|
|
}
|
|
if response != "final model text" {
|
|
t.Fatalf("runAgentLoop() response = %q, want %q", response, "final model text")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
t.Fatalf("unexpected outbound message when interim publish disabled: %+v", outbound)
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
ToolFeedback: config.ToolFeedbackConfig{
|
|
Enabled: true,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &picoInterleavedContentProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
agent.Tools.Register(&toolLimitTestTool{})
|
|
|
|
runCtx, runCancel := context.WithCancel(context.Background())
|
|
defer runCancel()
|
|
|
|
runDone := make(chan error, 1)
|
|
go func() {
|
|
runDone <- al.Run(runCtx)
|
|
}()
|
|
|
|
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
|
Channel: "pico",
|
|
SenderID: "user-1",
|
|
ChatID: "session-1",
|
|
Content: "run with tools",
|
|
}); err != nil {
|
|
t.Fatalf("PublishInbound() error = %v", err)
|
|
}
|
|
|
|
outputs := make([]bus.OutboundMessage, 0, 3)
|
|
deadline := time.After(2 * time.Second)
|
|
for len(outputs) < 2 {
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
outputs = append(outputs, outbound)
|
|
case <-deadline:
|
|
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
|
|
}
|
|
}
|
|
|
|
if outputs[0].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
|
|
t.Fatalf("first outbound = %+v, want tool_calls message", outputs[0])
|
|
}
|
|
if outputs[0].Content != "" {
|
|
t.Fatalf("first outbound content = %q, want empty tool_calls content", outputs[0].Content)
|
|
}
|
|
if !strings.Contains(outputs[0].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
|
|
t.Fatalf("first outbound tool_calls = %q, want tool name", outputs[0].Context.Raw[metadataKeyToolCalls])
|
|
}
|
|
if outputs[1].Content != "final model text" {
|
|
t.Fatalf("second outbound content = %q, want %q", outputs[1].Content, "final model text")
|
|
}
|
|
|
|
runCancel()
|
|
select {
|
|
case err := <-runDone:
|
|
if err != nil {
|
|
t.Fatalf("Run() error = %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for Run() to exit")
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
t.Fatalf("unexpected extra pico output after tool feedback + final reply: %+v", outbound)
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_ImageInjectsPathTag(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
|
|
pngPath := filepath.Join(dir, "test.png")
|
|
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
|
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
|
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
|
0x00, 0x00, 0x00, // no interlace
|
|
0x90, 0x77, 0x53, 0xDE, // CRC
|
|
}
|
|
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "describe this", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
|
|
}
|
|
localPath, _, _ := store.ResolveWithMeta(ref)
|
|
expectedContent := "describe this [image:" + localPath + "]"
|
|
if result[0].Content != expectedContent {
|
|
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_ToolRoleImageAppendedAsUserMessage(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
pngPath := filepath.Join(dir, "tool-result.png")
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
|
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
|
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
|
0x00, 0x00, 0x00, // no interlace
|
|
0x90, 0x77, 0x53, 0xDE, // CRC
|
|
}
|
|
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "tool", Content: "Image loaded", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
// Tool message should have path tag but no base64
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media in tool message, got %d", len(result[0].Media))
|
|
}
|
|
localPath, _, _ := store.ResolveWithMeta(ref)
|
|
if !strings.Contains(result[0].Content, "[image:"+localPath+"]") {
|
|
t.Fatalf("expected image path tag in tool content, got %q", result[0].Content)
|
|
}
|
|
|
|
// A synthetic user message with base64 should follow
|
|
if len(result) != 2 {
|
|
t.Fatalf("expected 2 messages (tool + synthetic user), got %d", len(result))
|
|
}
|
|
if result[1].Role != "user" {
|
|
t.Fatalf("expected synthetic message role=user, got %q", result[1].Role)
|
|
}
|
|
if len(result[1].Media) != 1 {
|
|
t.Fatalf("expected 1 base64 media in synthetic user message, got %d", len(result[1].Media))
|
|
}
|
|
if !strings.HasPrefix(result[1].Media[0], "data:image/png;base64,") {
|
|
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[1].Media[0][:40])
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_MultiToolCallPreservesOrdering(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
// Create image for tool #1
|
|
pngPath := filepath.Join(dir, "loaded.png")
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
|
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
|
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
|
}
|
|
os.WriteFile(pngPath, pngHeader, 0o644)
|
|
imgRef, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
|
|
|
// Simulate: assistant called load_image + read_file, two tool results follow
|
|
messages := []providers.Message{
|
|
{Role: "assistant", Content: "Let me load the image and read the file."},
|
|
{Role: "tool", Content: "Image loaded [image: photo]", Media: []string{imgRef}},
|
|
{Role: "tool", Content: "file contents here"},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
// assistant, tool#1, tool#2 must remain contiguous — no user in between
|
|
if result[0].Role != "assistant" {
|
|
t.Fatalf("result[0] expected assistant, got %q", result[0].Role)
|
|
}
|
|
if result[1].Role != "tool" {
|
|
t.Fatalf("result[1] expected tool, got %q", result[1].Role)
|
|
}
|
|
if result[2].Role != "tool" {
|
|
t.Fatalf("result[2] expected tool, got %q", result[2].Role)
|
|
}
|
|
|
|
// Synthetic user message should come AFTER the tool block
|
|
if len(result) != 4 {
|
|
t.Fatalf("expected 4 messages (assistant + 2 tool + synthetic user), got %d", len(result))
|
|
}
|
|
if result[3].Role != "user" {
|
|
t.Fatalf("result[3] expected user, got %q", result[3].Role)
|
|
}
|
|
if len(result[3].Media) != 1 || !strings.HasPrefix(result[3].Media[0], "data:image/png;base64,") {
|
|
t.Fatal("expected synthetic user message to contain base64 image")
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_OversizedImageSkipsBase64KeepsPathTag(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
bigPath := filepath.Join(dir, "big.png")
|
|
// Write PNG header + padding to exceed limit
|
|
data := make([]byte, 1024+1) // 1KB + 1 byte
|
|
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
|
|
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "hi", Media: []string{ref}},
|
|
}
|
|
// Use a tiny limit (1KB) so the file is oversized
|
|
result := resolveMediaRefs(messages, store, 1024)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
|
|
}
|
|
localPath, _, _ := store.ResolveWithMeta(ref)
|
|
expected := "hi [image:" + localPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_UnknownTypeInjectsPath(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
txtPath := filepath.Join(dir, "readme.txt")
|
|
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "hi", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media entries, got %d", len(result[0].Media))
|
|
}
|
|
expected := "hi [file:" + txtPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
|
|
}
|
|
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
|
|
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
pngPath := filepath.Join(dir, "test.png")
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
|
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
|
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
|
}
|
|
os.WriteFile(pngPath, pngHeader, 0o644)
|
|
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
|
|
|
original := []providers.Message{
|
|
{Role: "user", Content: "hi", Media: []string{ref}},
|
|
}
|
|
originalRef := original[0].Media[0]
|
|
|
|
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
|
|
|
|
if original[0].Media[0] != originalRef {
|
|
t.Fatal("resolveMediaRefs mutated original message slice")
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
// File with JPEG content but stored with explicit content type
|
|
jpegPath := filepath.Join(dir, "photo")
|
|
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
|
|
os.WriteFile(jpegPath, jpegHeader, 0o644)
|
|
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "hi", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
|
|
}
|
|
localPath, _, _ := store.ResolveWithMeta(ref)
|
|
expectedContent := "hi [image:" + localPath + "]"
|
|
if result[0].Content != expectedContent {
|
|
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_PDFInjectsFilePath(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
pdfPath := filepath.Join(dir, "report.pdf")
|
|
// PDF magic bytes
|
|
os.WriteFile(pdfPath, []byte("%PDF-1.4 test content"), 0o644)
|
|
ref, _ := store.Store(pdfPath, media.MediaMeta{ContentType: "application/pdf"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "report.pdf [file]", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media (non-image), got %d", len(result[0].Media))
|
|
}
|
|
expected := "report.pdf [file:" + pdfPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_AudioInjectsAudioPath(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
oggPath := filepath.Join(dir, "voice.ogg")
|
|
os.WriteFile(oggPath, []byte("fake audio"), 0o644)
|
|
ref, _ := store.Store(oggPath, media.MediaMeta{ContentType: "audio/ogg"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "voice.ogg [audio]", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media, got %d", len(result[0].Media))
|
|
}
|
|
expected := "voice.ogg [audio:" + oggPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_VideoInjectsVideoPath(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
mp4Path := filepath.Join(dir, "clip.mp4")
|
|
os.WriteFile(mp4Path, []byte("fake video"), 0o644)
|
|
ref, _ := store.Store(mp4Path, media.MediaMeta{ContentType: "video/mp4"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "clip.mp4 [video]", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media, got %d", len(result[0].Media))
|
|
}
|
|
expected := "clip.mp4 [video:" + mp4Path + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_NoGenericTagAppendsPath(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
csvPath := filepath.Join(dir, "data.csv")
|
|
os.WriteFile(csvPath, []byte("a,b,c"), 0o644)
|
|
ref, _ := store.Store(csvPath, media.MediaMeta{ContentType: "text/csv"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "here is my data", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
expected := "here is my data [file:" + csvPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestInjectPathTags_HandlesVariousChannelPlaceholders(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
content string
|
|
tag string
|
|
want string
|
|
}{
|
|
// Telegram / Feishu format
|
|
{"image_photo", "[image: photo]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
|
// WeCom / WeChat / Line format
|
|
{"bare_image", "[image]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
|
// QQ / Discord format with filename
|
|
{"image_filename", "[image: pic.jpg]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
|
{"audio_with_filename", "[audio: voice.m4a]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
|
|
{"bare_audio", "[audio]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
|
|
{"bare_video", "[video]", "[video:/tmp/v.mp4]", "[video:/tmp/v.mp4]"},
|
|
{"bare_file", "[file]", "[file:/tmp/f.pdf]", "[file:/tmp/f.pdf]"},
|
|
// Mixed surrounding text
|
|
{
|
|
"with_text",
|
|
"hello [image] world",
|
|
"[image:/tmp/p.png]",
|
|
"hello [image:/tmp/p.png] world",
|
|
},
|
|
// No placeholder — append
|
|
{"no_placeholder", "hello world", "[image:/tmp/p.png]", "hello world [image:/tmp/p.png]"},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got := injectPathTags(tc.content, []string{tc.tag})
|
|
if got != tc.want {
|
|
t.Errorf("expected %q, got %q", tc.want, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInjectPathTags_DoesNotReplacePathTag(t *testing.T) {
|
|
// If content already contains a path tag, we must not touch it.
|
|
content := "see [image:/already/placed.png] thanks"
|
|
got := injectPathTags(content, []string{"[image:/new/path.png]"})
|
|
want := "see [image:/already/placed.png] thanks [image:/new/path.png]"
|
|
if got != want {
|
|
t.Fatalf("expected %q, got %q", want, got)
|
|
}
|
|
}
|
|
|
|
func TestInjectPathTags_PrependsForJSONContent(t *testing.T) {
|
|
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
|
|
got := injectPathTags(jsonContent, []string{"[image:/tmp/photo.png]"})
|
|
want := "[image:/tmp/photo.png]\n" + jsonContent
|
|
if got != want {
|
|
t.Fatalf("expected tag prepended to JSON, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestInjectPathTags_BracketTextNotTreatedAsJSON(t *testing.T) {
|
|
content := "[update] see attached report"
|
|
got := injectPathTags(content, []string{"[file:/tmp/report.pdf]"})
|
|
want := "[update] see attached report [file:/tmp/report.pdf]"
|
|
if got != want {
|
|
t.Fatalf("expected tag appended to bracket text, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_JSONContentPrependsPathTag(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
pngPath := filepath.Join(dir, "card_img.png")
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
|
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
|
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
|
}
|
|
os.WriteFile(pngPath, pngHeader, 0o644)
|
|
ref, _ := store.Store(pngPath, media.MediaMeta{ContentType: "image/png"}, "test")
|
|
|
|
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: jsonContent, Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
want := "[image:" + pngPath + "]\n" + jsonContent
|
|
if result[0].Content != want {
|
|
t.Fatalf("expected path tag prepended to JSON content, got %q", result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_EmptyContentGetsPathTag(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
docPath := filepath.Join(dir, "doc.docx")
|
|
os.WriteFile(docPath, []byte("fake docx"), 0o644)
|
|
docxMIME := "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
|
ref, _ := store.Store(docPath, media.MediaMeta{ContentType: docxMIME}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "", Media: []string{ref}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
expected := "[file:" + docPath + "]"
|
|
if result[0].Content != expected {
|
|
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
|
|
store := media.NewFileMediaStore()
|
|
dir := t.TempDir()
|
|
|
|
pngPath := filepath.Join(dir, "photo.png")
|
|
pngHeader := []byte{
|
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
|
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
|
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
|
}
|
|
os.WriteFile(pngPath, pngHeader, 0o644)
|
|
imgRef, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
|
|
|
pdfPath := filepath.Join(dir, "report.pdf")
|
|
os.WriteFile(pdfPath, []byte("%PDF-1.4 test"), 0o644)
|
|
fileRef, _ := store.Store(pdfPath, media.MediaMeta{ContentType: "application/pdf"}, "test")
|
|
|
|
messages := []providers.Message{
|
|
{Role: "user", Content: "check these [file]", Media: []string{imgRef, fileRef}},
|
|
}
|
|
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
|
|
|
if len(result[0].Media) != 0 {
|
|
t.Fatalf("expected 0 media (all types use path tags), got %d", len(result[0].Media))
|
|
}
|
|
imgLocalPath, _, _ := store.ResolveWithMeta(imgRef)
|
|
pdfLocalPath, _, _ := store.ResolveWithMeta(fileRef)
|
|
expectedContent := "check these [file:" + pdfLocalPath + "] [image:" + imgLocalPath + "]"
|
|
if result[0].Content != expectedContent {
|
|
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
|
}
|
|
}
|
|
|
|
// --- Native search helper tests ---
|
|
|
|
// nativeSearchProvider is a test provider whose SupportsNativeSearch answer
// is controlled by the supported field.
type nativeSearchProvider struct {
	// supported is the value reported by SupportsNativeSearch.
	supported bool
}
|
|
|
|
func (p *nativeSearchProvider) Chat(
|
|
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
|
model string, opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{Content: "ok"}, nil
|
|
}
|
|
|
|
// GetDefaultModel returns the fixed model name "test-model".
func (p *nativeSearchProvider) GetDefaultModel() string {
	const model = "test-model"
	return model
}
|
|
|
|
// SupportsNativeSearch reports the value of the provider's supported field.
func (p *nativeSearchProvider) SupportsNativeSearch() bool {
	supported := p.supported
	return supported
}
|
|
|
|
// plainProvider is a stateless test provider. The methods visible on it in
// this file are Chat and GetDefaultModel; it has no SupportsNativeSearch
// method.
type plainProvider struct {
}
|
|
|
|
func (p *plainProvider) Chat(
|
|
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
|
model string, opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{Content: "ok"}, nil
|
|
}
|
|
|
|
// GetDefaultModel returns the fixed model name "test-model".
func (p *plainProvider) GetDefaultModel() string {
	const model = "test-model"
	return model
}
|
|
|
|
func TestIsNativeSearchProvider_Supported(t *testing.T) {
|
|
if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) {
|
|
t.Fatal("expected true for provider that supports native search")
|
|
}
|
|
}
|
|
|
|
func TestIsNativeSearchProvider_NotSupported(t *testing.T) {
|
|
if isNativeSearchProvider(&nativeSearchProvider{supported: false}) {
|
|
t.Fatal("expected false for provider that does not support native search")
|
|
}
|
|
}
|
|
|
|
func TestIsNativeSearchProvider_NoInterface(t *testing.T) {
|
|
if isNativeSearchProvider(&plainProvider{}) {
|
|
t.Fatal("expected false for provider that does not implement NativeSearchCapable")
|
|
}
|
|
}
|
|
|
|
func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) {
|
|
defs := []providers.ToolDefinition{
|
|
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "web_search"}},
|
|
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
|
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
|
}
|
|
result := filterClientWebSearch(defs)
|
|
if len(result) != 2 {
|
|
t.Fatalf("len(result) = %d, want 2", len(result))
|
|
}
|
|
for _, td := range result {
|
|
if td.Function.Name == "web_search" {
|
|
t.Fatal("web_search should be filtered out")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFilterClientWebSearch_NoWebSearch(t *testing.T) {
|
|
defs := []providers.ToolDefinition{
|
|
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
|
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
|
}
|
|
result := filterClientWebSearch(defs)
|
|
if len(result) != 2 {
|
|
t.Fatalf("len(result) = %d, want 2", len(result))
|
|
}
|
|
}
|
|
|
|
func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
|
|
result := filterClientWebSearch(nil)
|
|
if len(result) != 0 {
|
|
t.Fatalf("len(result) = %d, want 0", len(result))
|
|
}
|
|
}
|
|
|
|
type overflowProvider struct {
|
|
calls int
|
|
lastMessages []providers.Message
|
|
chatFunc func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any) (*providers.LLMResponse, error)
|
|
}
|
|
|
|
func (p *overflowProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.calls++
|
|
p.lastMessages = append([]providers.Message(nil), messages...)
|
|
|
|
if p.chatFunc != nil {
|
|
return p.chatFunc(ctx, messages, tools, model, opts)
|
|
}
|
|
|
|
if p.calls == 1 {
|
|
return nil, errors.New("context_window_exceeded")
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "Recovered from overflow",
|
|
}, nil
|
|
}
|
|
|
|
func (p *overflowProvider) GetDefaultModel() string {
|
|
return "test-model"
|
|
}
|
|
|
|
func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
|
|
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
_ = cfg
|
|
|
|
provider := &overflowProvider{}
|
|
al.registry = NewAgentRegistry(al.cfg, provider)
|
|
|
|
sessionKey := "agent:main:test-session"
|
|
agent := al.GetRegistry().GetDefaultAgent()
|
|
|
|
for i := 0; i < 5; i++ {
|
|
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
|
|
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
|
|
}
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "test",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
SessionKey: "test-session",
|
|
Content: "trigger recovery",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if response != "Recovered from overflow" {
|
|
t.Fatalf("response = %q, want %q", response, "Recovered from overflow")
|
|
}
|
|
|
|
if provider.calls != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
|
|
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
_ = cfg
|
|
|
|
provider := &overflowProvider{}
|
|
al.registry = NewAgentRegistry(al.cfg, provider)
|
|
|
|
recoveryMsg := "error: status 400: context_window_exceeded"
|
|
|
|
provider.chatFunc = func(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
if provider.calls == 1 {
|
|
return nil, errors.New(recoveryMsg)
|
|
}
|
|
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
|
|
}
|
|
|
|
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
|
Channel: "test",
|
|
ChatID: "chat1",
|
|
SenderID: "user1",
|
|
Content: "hello",
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("processMessage() error = %v", err)
|
|
}
|
|
if !strings.Contains(response, "Anthropic recovery success") {
|
|
t.Fatalf("response = %q, want success message", response)
|
|
}
|
|
if provider.calls != 2 {
|
|
t.Fatalf("expected 2 calls for retry, got %d", provider.calls)
|
|
}
|
|
}
|
|
|
|
func TestParallelMessageProcessing_DifferentSessionsProcessedConcurrently(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
// Track concurrent executions using a unique ID per turn
|
|
var mu sync.Mutex
|
|
activeTurns := make(map[string]bool)
|
|
maxConcurrent := 0
|
|
turnCounter := 0
|
|
var wg sync.WaitGroup
|
|
wg.Add(3) // Wait for 3 turns to complete
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
MaxParallelTurns: 3, // Allow up to 3 concurrent turns
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"chat"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
defer msgBus.Close()
|
|
|
|
// Create a slow mock provider that tracks concurrency
|
|
provider := &concurrentMockProvider{
|
|
responseFunc: func(callID int) string {
|
|
mu.Lock()
|
|
turnCounter++
|
|
turnID := fmt.Sprintf("turn-%d", turnCounter)
|
|
activeTurns[turnID] = true
|
|
currentActive := len(activeTurns)
|
|
if currentActive > maxConcurrent {
|
|
maxConcurrent = currentActive
|
|
}
|
|
mu.Unlock()
|
|
|
|
// Simulate some processing time
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
mu.Lock()
|
|
delete(activeTurns, turnID)
|
|
mu.Unlock()
|
|
|
|
wg.Done()
|
|
return fmt.Sprintf("Response %s", turnID)
|
|
},
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
defer al.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Start the agent loop
|
|
go func() {
|
|
if err := al.Run(ctx); err != nil {
|
|
t.Logf("Agent loop error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Give the loop time to start
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send 3 messages from different sessions
|
|
sessions := []string{"user1", "user2", "user3"}
|
|
for i, session := range sessions {
|
|
msg := bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: fmt.Sprintf("chat%d", i),
|
|
ChatType: "direct",
|
|
SenderID: session,
|
|
},
|
|
Channel: "telegram",
|
|
ChatID: fmt.Sprintf("chat%d", i),
|
|
SenderID: session,
|
|
Content: fmt.Sprintf("Hello from %s", session),
|
|
}
|
|
if err := msgBus.PublishInbound(context.Background(), msg); err != nil {
|
|
t.Fatalf("PublishInbound failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Wait for all turns to complete with timeout
|
|
done := make(chan struct{})
|
|
go func() {
|
|
wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// All turns completed successfully
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for turns to complete")
|
|
}
|
|
|
|
// Verify that we had concurrent executions
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if maxConcurrent < 2 {
|
|
t.Errorf("Expected at least 2 concurrent executions, got max %d", maxConcurrent)
|
|
}
|
|
|
|
t.Logf("Maximum concurrent executions: %d", maxConcurrent)
|
|
}
|
|
|
|
func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
var mu sync.Mutex
|
|
turnIDs := make(map[string]bool)
|
|
var wg sync.WaitGroup
|
|
wg.Add(1) // Only 1 turn should be created for same session
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
MaxParallelTurns: 3,
|
|
},
|
|
},
|
|
Session: config.SessionConfig{
|
|
Dimensions: []string{"chat"},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
defer msgBus.Close()
|
|
|
|
al := NewAgentLoop(cfg, msgBus, &concurrentMockProvider{
|
|
responseFunc: func(callID int) string {
|
|
wg.Done()
|
|
return "ok"
|
|
},
|
|
})
|
|
defer al.Close()
|
|
|
|
sub := al.SubscribeEvents(64)
|
|
|
|
go func() {
|
|
for evt := range sub.C {
|
|
if evt.Kind == EventKindTurnStart {
|
|
mu.Lock()
|
|
turnIDs[evt.Meta.TurnID] = true
|
|
mu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
go func() {
|
|
if err := al.Run(ctx); err != nil {
|
|
t.Logf("Agent loop error: %v", err)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send 3 messages from the SAME session - only one turn should be created;
|
|
// subsequent messages should be enqueued to the steering queue and processed
|
|
// within the same turn (not as separate concurrent turns).
|
|
for i := 0; i < 3; i++ {
|
|
msg := bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "chat1",
|
|
ChatType: "direct",
|
|
SenderID: "user1",
|
|
},
|
|
Channel: "telegram",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: fmt.Sprintf("Message %d", i+1),
|
|
}
|
|
if err := msgBus.PublishInbound(context.Background(), msg); err != nil {
|
|
t.Fatalf("PublishInbound failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Wait for turn to complete with timeout
|
|
done := make(chan struct{})
|
|
go func() {
|
|
wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Turn completed successfully
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for turn to complete")
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Only 1 turn ID should have been created — proving messages were
|
|
// serialized into a single turn rather than spawning concurrent turns.
|
|
if len(turnIDs) != 1 {
|
|
t.Errorf("Expected 1 turn (others queued to steering), got %d: %v", len(turnIDs), turnIDs)
|
|
}
|
|
}
|
|
|
|
// concurrentMockProvider is a mock provider whose response text is produced
// by a caller-supplied callback, which lets tests observe and time each Chat
// call.
type concurrentMockProvider struct {
	// responseFunc, when non-nil, is invoked by Chat and its return value
	// becomes the response content.
	responseFunc func(callID int) string
}
|
|
|
|
func (p *concurrentMockProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
// Use an atomic counter to assign unique call IDs for concurrency tracking.
|
|
// This avoids relying on sessionKey derivation from message content, which
|
|
// is not deterministic across concurrent calls.
|
|
response := "Mock response"
|
|
if p.responseFunc != nil {
|
|
response = p.responseFunc(len(messages))
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: response,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (p *concurrentMockProvider) GetDefaultModel() string {
|
|
return "test-model"
|
|
}
|