From df4f322f09239595bb9ddf163c885f936efff045 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 12:05:28 +0100 Subject: [PATCH] fix(tool): route binary outputs through the media pipeline. --- pkg/agent/loop.go | 145 +++++++++++++---- pkg/agent/loop_test.go | 296 +++++++++++++++++++++++++++++++++++ pkg/channels/manager.go | 42 ++++- pkg/channels/manager_test.go | 70 +++++++++ pkg/tools/mcp_tool.go | 280 +++++++++++++++++++++++++++++++-- pkg/tools/mcp_tool_test.go | 144 +++++++++++++++++ pkg/tools/normalization.go | 292 ++++++++++++++++++++++++++++++++++ pkg/tools/registry.go | 39 ++++- pkg/tools/registry_test.go | 111 +++++++++++++ pkg/tools/result.go | 58 ++++++- pkg/tools/result_test.go | 39 +++++ pkg/tools/send_file.go | 2 +- pkg/tools/send_file_test.go | 3 + pkg/tools/toolloop.go | 5 +- 14 files changed, 1462 insertions(+), 64 deletions(-) create mode 100644 pkg/tools/normalization.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index ed5c73afc..2b6672f9c 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -492,13 +492,13 @@ func (al *AgentLoop) GetConfig() *config.Config { func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s - // Propagate store to send_file tools in all agents. + // Propagate store to all registered tools that can emit media. registry := al.GetRegistry() - registry.ForEachTool("send_file", func(t tools.Tool) { - if sf, ok := t.(*tools.SendFileTool); ok { - sf.SetMediaStore(s) + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { + agent.Tools.SetMediaStore(s) } - }) + } } // SetTranscriber injects a voice transcriber for agent-level audio transcription. @@ -926,13 +926,26 @@ func (al *AgentLoop) runAgentLoop( agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 3. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) + finalContent, iteration, responseHandled, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } - // If last tool had ForUser content and we already sent it, we might not need to send final response - // This is controlled by the tool's Silent flag and ForUser content + if responseHandled { + agent.Sessions.Save(opts.SessionKey) + + if opts.EnableSummary { + al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) + } + + logger.InfoCF("agent", "Response already handled by tool output", + map[string]any{ + "agent_id": agent.ID, + "session_key": opts.SessionKey, + "iterations": iteration, + }) + return "", nil + } // 4. Handle empty response if finalContent == "" { @@ -1030,14 +1043,57 @@ func (al *AgentLoop) handleReasoning( } } +const handledToolResponseSummary = "Requested output delivered via tool attachment." + +func (al *AgentLoop) buildOutboundMediaMessage( + channel string, + chatID string, + refs []string, +) bus.OutboundMediaMessage { + parts := make([]bus.MediaPart, 0, len(refs)) + for _, ref := range refs { + part := bus.MediaPart{Ref: ref} + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + return bus.OutboundMediaMessage{ + Channel: channel, + ChatID: chatID, + Parts: parts, + } +} + +func (al *AgentLoop) buildArtifactTags(refs []string) []string { + if al.mediaStore == nil || len(refs) == 0 { + return nil + } + + tags := make([]string, 0, len(refs)) + for _, ref := range refs { + localPath, meta, err := al.mediaStore.ResolveWithMeta(ref) + if err != nil { + continue + } + mime := detectMIME(localPath, meta) + tags = append(tags, buildPathTag(mime, localPath)) + } + return tags +} + // runLLMIteration executes the LLM call loop with tool handling. -// Returns (finalContent, iteration, error). +// Returns (finalContent, iteration, responseHandled, error). func (al *AgentLoop) runLLMIteration( ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions, -) (string, int, error) { +) (string, int, bool, error) { iteration := 0 var finalContent string @@ -1240,7 +1296,7 @@ func (al *AgentLoop) runLLMIteration( "model": activeModel, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) + return "", iteration, false, fmt.Errorf("LLM call failed after retries: %w", err) } go al.handleReasoning( @@ -1401,10 +1457,7 @@ func (al *AgentLoop) runLLMIteration( } // Determine content for the agent loop (ForLLM or error). - content := result.ForLLM - if content == "" && result.Err != nil { - content = result.Err.Error() - } + content := result.ContentForLLM() if content == "" { return } @@ -1439,8 +1492,14 @@ func (al *AgentLoop) runLLMIteration( } wg.Wait() + allResponsesHandled := len(agentResults) > 0 + // Process results in original order (send to user, save to session) for _, r := range agentResults { + if !r.result.ResponseHandled { + allResponsesHandled = false + } + // Send ForUser content to user immediately if not Silent if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ @@ -1455,32 +1514,33 @@ func (al *AgentLoop) runLLMIteration( }) } - // If tool returned media refs, publish them as outbound media + // If tool returned media refs, publish them as outbound media only when the + // tool explicitly marked the user-visible delivery as already handled. if len(r.result.Media) > 0 { - parts := make([]bus.MediaPart, 0, len(r.result.Media)) - for _, ref := range r.result.Media { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) + outboundMedia := al.buildOutboundMediaMessage(opts.Channel, opts.ChatID, r.result.Media) + if r.result.ResponseHandled { + if al.channelManager != nil { + if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { + allResponsesHandled = false + logger.WarnCF("agent", "Synchronous media send failed, falling back to bus delivery", + map[string]any{ + "agent_id": agent.ID, + "tool": r.tc.Name, + "error": err.Error(), + }) + al.bus.PublishOutboundMedia(ctx, outboundMedia) } + } else { + al.bus.PublishOutboundMedia(ctx, outboundMedia) } - parts = append(parts, part) } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, - Parts: parts, - }) } // Determine content for LLM based on tool result - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() + if len(r.result.Media) > 0 && !r.result.ResponseHandled { + r.result.ArtifactTags = al.buildArtifactTags(r.result.Media) } + contentForLLM := r.result.ContentForLLM() toolResultMsg := providers.Message{ Role: "tool", @@ -1493,6 +1553,23 @@ func (al *AgentLoop) runLLMIteration( agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } + if allResponsesHandled { + summaryMsg := providers.Message{ + Role: "assistant", + Content: handledToolResponseSummary, + } + messages = append(messages, summaryMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, summaryMsg) + + logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "tool_count": len(agentResults), + }) + return "", iteration, true, nil + } + // Tick down TTL of discovered tools after processing tool results. // Only reached when tool calls were made (the loop continues); // the break on no-tool-call responses skips this. @@ -1505,7 +1582,7 @@ func (al *AgentLoop) runLLMIteration( }) } - return finalContent, iteration, nil + return finalContent, iteration, false, nil } // selectCandidates returns the model candidates and resolved model name to use diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 28eab03db..ea63d9634 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -298,6 +298,152 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { } } +func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + + imagePath := filepath.Join(tmpDir, "screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response when media tool already handled delivery, got %q", response) + } + if provider.calls != 1 { + t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls) + } + if len(provider.toolCounts) != 1 { + t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts)) + } + if provider.toolCounts[0] == 0 { + t.Fatal("expected tools to be available on the first LLM call") + } + + select { + case mediaMsg := <-msgBus.OutboundMediaChan(): + if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { + t.Fatalf("unexpected outbound media target: %+v", mediaMsg) + } + if len(mediaMsg.Parts) != 1 { + t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) + } + default: + t.Fatal("expected outbound media message to be published") + } + + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + sessionKey := resolveScopeKey(route, "") + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) == 0 { + t.Fatal("expected session history to be saved") + } + last := history[len(history)-1] + if last.Role != "assistant" || last.Content != handledToolResponseSummary { + t.Fatalf("expected handled assistant summary in history, got %+v", last) + } +} + +func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { + tmpDir := t.TempDir() + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + + msgBus := bus.NewMessageBus() + provider := &artifactThenSendProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + + mediaDir := media.TempDir() + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + t.Fatalf("MkdirAll(mediaDir) error = %v", err) + } + imagePath := filepath.Join(mediaDir, "artifact-screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&mediaArtifactTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response after send_file handled delivery, got %q", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls) + } + + select { + case mediaMsg := <-msgBus.OutboundMediaChan(): + if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { + t.Fatalf("unexpected outbound media target: %+v", mediaMsg) + } + if len(mediaMsg.Parts) != 1 { + t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) + } + default: + t.Fatal("expected outbound media from send_file") + } +} + // TestAgentLoop_GetStartupInfo verifies startup info contains tools func TestAgentLoop_GetStartupInfo(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -420,6 +566,98 @@ func (m *countingMockProvider) GetDefaultModel() string { return "counting-mock-model" } +type handledMediaProvider struct { + calls int + toolCounts []int +} + +func (m *handledMediaProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + m.toolCounts = append(m.toolCounts, len(tools)) + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_media", + Type: "function", + Name: "handled_media_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *handledMediaProvider) GetDefaultModel() string { + return "handled-media-model" +} + +type artifactThenSendProvider struct { + calls int +} + +func (m *artifactThenSendProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_artifact_media", + Type: "function", + Name: "media_artifact_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + var artifactPath string + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != "tool" { + continue + } + start := strings.Index(messages[i].Content, "[file:") + if start < 0 { + continue + } + rest := messages[i].Content[start+len("[file:"):] + end := strings.Index(rest, "]") + if end < 0 { + continue + } + artifactPath = rest[:end] + break + } + if artifactPath == "" { + return nil, fmt.Errorf("provider did not receive artifact path in tool result") + } + + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_send_file", + Type: "function", + Name: "send_file", + Arguments: map[string]any{"path": artifactPath}, + }}, + }, nil +} + +func (m *artifactThenSendProvider) GetDefaultModel() string { + return "artifact-then-send-model" +} + type toolLimitOnlyProvider struct{} func (m *toolLimitOnlyProvider) Chat( @@ -465,6 +703,64 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool return tools.SilentResult("Custom tool executed") } +type handledMediaTool struct { + store media.MediaStore + path string +} + +func (m *handledMediaTool) Name() string { return "handled_media_tool" } +func (m *handledMediaTool) Description() string { + return "Returns a media attachment and fully handles the user response" +} + +func (m *handledMediaTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_tool", + }, "test:handled_media") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + +type mediaArtifactTool struct { + store media.MediaStore + path string +} + +func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" } +func (m *mediaArtifactTool) Description() string { + return "Returns a media artifact that the agent can forward or save later" +} + +func (m *mediaArtifactTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:media_artifact_tool", + }, "test:media_artifact") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Artifact created.", []string{ref}) +} + type toolLimitTestTool struct{} func (m *toolLimitTestTool) Name() string { diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index ff3fa399c..f4a64807e 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -771,7 +771,7 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor if !ok { return } - m.sendMediaWithRetry(ctx, name, w, msg) + _ = m.sendMediaWithRetry(ctx, name, w, msg) case <-ctx.Done(): return } @@ -779,26 +779,31 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor } // sendMediaWithRetry sends a media message through the channel with rate limiting and -// retry logic. If the channel does not implement MediaSender, it silently skips. -func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { +// retry logic. It returns nil on success, or the last error after retries. +func (m *Manager) sendMediaWithRetry( + ctx context.Context, + name string, + w *channelWorker, + msg bus.OutboundMediaMessage, +) error { ms, ok := w.ch.(MediaSender) if !ok { logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ "channel": name, }) - return + return nil } // Rate limit: wait for token if err := w.limiter.Wait(ctx); err != nil { - return + return err } var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = ms.SendMedia(ctx, msg) if lastErr == nil { - return + return nil } // Permanent failures — don't retry @@ -817,7 +822,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe case <-time.After(rateLimitDelay): continue case <-ctx.Done(): - return + return ctx.Err() } } @@ -826,7 +831,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe select { case <-time.After(backoff): case <-ctx.Done(): - return + return ctx.Err() } } @@ -837,6 +842,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe "error": lastErr.Error(), "retries": maxRetries, }) + return lastErr } // runTTLJanitor periodically scans the typingStops and placeholders maps @@ -1029,6 +1035,26 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro return nil } +// SendMedia sends outbound media synchronously through the channel worker's +// rate limiter and retry logic. It blocks until the media is delivered (or all +// retries are exhausted), which preserves ordering when later agent behavior +// depends on actual media delivery. +func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + return fmt.Errorf("channel %s not found", msg.Channel) + } + if !wExists || w == nil { + return fmt.Errorf("channel %s has no active worker", msg.Channel) + } + + return m.sendMediaWithRetry(ctx, msg.Channel, w, msg) +} + func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() _, exists := m.channels[channelName] diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 7dfec9ebf..6a5dd7e30 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -43,6 +43,20 @@ func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, conten return nil } +type mockMediaChannel struct { + mockChannel + sendMediaFn func(ctx context.Context, msg bus.OutboundMediaMessage) error + sentMediaMessages []bus.OutboundMediaMessage +} + +func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.sentMediaMessages = append(m.sentMediaMessages, msg) + if m.sendMediaFn != nil { + return m.sendMediaFn(ctx, msg) + } + return nil +} + // newTestManager creates a minimal Manager suitable for unit tests. func newTestManager() *Manager { return &Manager{ @@ -208,6 +222,62 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { } } +func TestSendMedia_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if callCount != 1 { + t.Fatalf("expected 1 SendMedia call, got %d", callCount) + } +} + +func TestSendMedia_PropagatesFailure(t *testing.T) { + m := newTestManager() + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + return fmt.Errorf("bad upload: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err == nil { + t.Fatal("expected SendMedia to return error") + } + if !errors.Is(err, ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } +} + func TestSendWithRetry_UnknownError(t *testing.T) { m := newTestManager() var callCount int diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go index 6e53cf354..5bffb4e89 100644 --- a/pkg/tools/mcp_tool.go +++ b/pkg/tools/mcp_tool.go @@ -5,9 +5,13 @@ import ( "encoding/json" "fmt" "hash/fnv" + "os" "strings" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MCPManager defines the interface for MCP manager operations @@ -25,6 +29,7 @@ type MCPTool struct { manager MCPManager serverName string tool *mcp.Tool + mediaStore media.MediaStore } // NewMCPTool creates a new MCP tool wrapper @@ -36,6 +41,10 @@ func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool } } +func (t *MCPTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + // sanitizeIdentifierComponent normalizes a string so it can be safely used // as part of a tool/function identifier for downstream providers. // It: @@ -218,13 +227,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult WithError(fmt.Errorf("MCP tool error: %s", errMsg)) } - // Extract text content from result - output := extractContentText(result.Content) - - return &ToolResult{ - ForLLM: output, - IsError: false, - } + return t.normalizeResultContent(ctx, result.Content) } // extractContentText extracts text from MCP content array @@ -233,14 +236,269 @@ func extractContentText(content []mcp.Content) string { for _, c := range content { switch v := c.(type) { case *mcp.TextContent: - parts = append(parts, v.Text) + parts = append(parts, sanitizeToolLLMContent(v.Text)) case *mcp.ImageContent: - // For images, just indicate that an image was returned - parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.AudioContent: + parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.ResourceLink: + parts = append(parts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + parts = append(parts, summarizeEmbeddedResource(v)) default: // For other content types, use string representation parts = append(parts, fmt.Sprintf("[Content: %T]", v)) } } - return strings.Join(parts, "\n") + return sanitizeToolLLMContent(strings.Join(parts, "\n")) +} + +func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult { + llmParts := make([]string, 0, len(content)) + mediaRefs := make([]string, 0, len(content)) + + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + text := strings.TrimSpace(sanitizeToolLLMContent(v.Text)) + if text != "" { + llmParts = append(llmParts, text) + } + case *mcp.ImageContent: + ref, note := t.storeBinaryContent( + ctx, + "image", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.AudioContent: + ref, note := t.storeBinaryContent( + ctx, + "audio", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.ResourceLink: + llmParts = append(llmParts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + ref, note := t.storeEmbeddedResource(ctx, v) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + default: + llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v)) + } + } + + result := &ToolResult{ + ForLLM: strings.Join(compactStrings(llmParts), "\n"), + Media: mediaRefs, + } + return result +} + +func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) { + if content == nil || content.Resource == nil { + return "", "[MCP returned an embedded resource without data.]" + } + + resource := content.Resource + if len(resource.Blob) > 0 { + return t.storeBinaryContent( + ctx, + "resource", + normalizedMIMEType(resource.MIMEType), + resource.Blob, + content.Annotations, + ) + } + + if strings.TrimSpace(resource.Text) != "" { + return "", sanitizeToolLLMContent(resource.Text) + } + + return "", summarizeEmbeddedResource(content) +} + +func (t *MCPTool) storeBinaryContent( + ctx context.Context, + kind string, + mimeType string, + data []byte, + annotations *mcp.Annotations, +) (string, string) { + if len(data) == 0 { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType) + } + if !annotationsAllowUser(annotations) { + return "", fmt.Sprintf( + "[MCP returned %s content (%s) for non-user audience; omitted from model context.]", + kind, + mimeType, + ) + } + if t.mediaStore == nil { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]", + kind, + mimeType, + ) + } + + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) + if channel == "" || chatID == "" { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because no target chat was available.]", + kind, + mimeType, + ) + } + + dir := media.TempDir() + if err := os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + ext := extensionForMIMEType(mimeType) + tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext) + if err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + tmpPath := tmpFile.Name() + if _, err = tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + if err = tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + scope := fmt.Sprintf( + "tool:mcp:%s:%s:%s:%d", + sanitizeIdentifierComponent(t.serverName), + channel, + chatID, + time.Now().UnixNano(), + ) + filename := fmt.Sprintf( + "%s_%s%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ext, + ) + + ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: mimeType, + Source: fmt.Sprintf( + "tool:mcp:%s:%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ), + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf( + "[MCP returned %s content (%s) but it could not be registered as media.]", + kind, + mimeType, + ) + } + + return ref, fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]", + kind, + mimeType, + ) +} + +func summarizeResourceLink(content *mcp.ResourceLink) string { + if content == nil { + return "[MCP returned an empty resource link.]" + } + + parts := []string{"[MCP returned resource link"} + if content.Name != "" { + parts = append(parts, fmt.Sprintf("name=%q", content.Name)) + } + if content.URI != "" { + parts = append(parts, fmt.Sprintf("uri=%q", content.URI)) + } + if content.MIMEType != "" { + parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType)) + } + if content.Description != "" { + desc := strings.TrimSpace(content.Description) + if len(desc) > 200 { + desc = desc[:200] + "..." + } + parts = append(parts, fmt.Sprintf("description=%q", desc)) + } + return strings.Join(parts, ", ") + "]" +} + +func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string { + if content == nil || content.Resource == nil { + return "[MCP returned an embedded resource.]" + } + + resource := content.Resource + if resource.URI != "" { + return fmt.Sprintf( + "[MCP returned embedded resource %q (%s).]", + resource.URI, + normalizedMIMEType(resource.MIMEType), + ) + } + return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType)) +} + +func annotationsAllowUser(annotations *mcp.Annotations) bool { + if annotations == nil || len(annotations.Audience) == 0 { + return true + } + for _, audience := range annotations.Audience { + if strings.EqualFold(string(audience), "user") { + return true + } + } + return false +} + +func normalizedMIMEType(mimeType string) string { + if strings.TrimSpace(mimeType) == "" { + return "application/octet-stream" + } + return mimeType +} + +func compactStrings(parts []string) []string { + compact := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + compact = append(compact, part) + } + return compact } diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go index 95bb0f992..8bbac3bc7 100644 --- a/pkg/tools/mcp_tool_test.go +++ b/pkg/tools/mcp_tool_test.go @@ -3,10 +3,14 @@ package tools import ( "context" "fmt" + "os" + "path/filepath" "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MockMCPManager is a mock implementation of MCPManager interface for testing @@ -490,3 +494,143 @@ func TestMCPTool_Parameters_MapSchema(t *testing.T) { t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) } } + +func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("fake-image-bytes"), + MIMEType: "image/png", + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if result.IsError { + t.Fatalf("expected success, got %q", result.ForLLM) + } + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if result.ResponseHandled { + t.Fatal("expected MCP image artifact not to mark response as handled") + } + if !strings.Contains(result.ForLLM, "stored as a local media artifact") { + t.Fatalf("expected local media artifact note, got %q", result.ForLLM) + } + + path, meta, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if meta.ContentType != "image/png" { + t.Fatalf("expected image/png content type, got %q", meta.ContentType) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected png temp file, got %q", path) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "fake-image-bytes" { + t.Fatalf("expected stored media bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.EmbeddedResource{ + Resource: &mcp.ResourceContents{ + URI: "file:///tmp/report.png", + MIMEType: "image/png", + Blob: []byte("blob-bytes"), + }, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 1 { + t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media)) + } + path, _, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "blob-bytes" { + t.Fatalf("expected stored blob bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("assistant-only"), + MIMEType: "image/png", + Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}}, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 0 { + t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media)) + } + if !strings.Contains(result.ForLLM, "non-user audience") { + t.Fatalf("expected audience note, got %q", result.ForLLM) + } +} + +func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Repeat("QUJD", 400)}, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"}) + + result := mcpTool.Execute(context.Background(), nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/normalization.go b/pkg/tools/normalization.go new file mode 100644 index 000000000..3a76c5d92 --- /dev/null +++ b/pkg/tools/normalization.go @@ -0,0 +1,292 @@ +package tools + +import ( + "encoding/base64" + "fmt" + "mime" + "os" + "path/filepath" + "regexp" + "strings" + "time" + "unicode" + + "github.com/sipeed/picoclaw/pkg/media" +) + +const ( + largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]" + inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]" + inlineMediaStoredMessage = "[Tool returned inline media content (%s); omitted from model context and registered as a media attachment.]" +) + +var ( + inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`) + inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`) +) + +func normalizeToolResult( + result *ToolResult, + toolName string, + store media.MediaStore, + channel string, + chatID string, +) *ToolResult { + if result == nil { + return nil + } + + notes := make([]string, 0, 2) + seen := make(map[string]struct{}) + + if store != nil && channel != "" && chatID != "" { + var refs []string + var extractedNotes []string + + result.ForLLM, refs, extractedNotes = extractInlineMediaRefs( + result.ForLLM, + toolName, + store, + channel, + chatID, + seen, + ) + result.Media = append(result.Media, refs...) + notes = append(notes, extractedNotes...) + + result.ForUser, refs, extractedNotes = extractInlineMediaRefs( + result.ForUser, + toolName, + store, + channel, + chatID, + seen, + ) + result.Media = append(result.Media, refs...) + notes = append(notes, extractedNotes...) + } + + result.ForLLM = sanitizeToolLLMContent(result.ForLLM) + + if len(result.Media) > 0 && len(notes) > 0 { + if strings.TrimSpace(result.ForLLM) == "" { + result.ForLLM = strings.Join(notes, "\n") + } else { + result.ForLLM = strings.TrimSpace(result.ForLLM) + "\n" + strings.Join(notes, "\n") + } + } + if len(result.Media) > 0 && strings.TrimSpace(result.ForLLM) == "" { + result.ForLLM = "[Tool returned media content; omitted from model context and registered as a media attachment.]" + } + + return result +} + +func sanitizeToolLLMContent(text string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return text + } + if inlineMarkdownDataURLRe.MatchString(trimmed) || inlineRawDataURLRe.MatchString(trimmed) { + cleaned := inlineMarkdownDataURLRe.ReplaceAllString(trimmed, "") + cleaned = inlineRawDataURLRe.ReplaceAllString(cleaned, "") + cleaned = strings.TrimSpace(cleaned) + if cleaned == "" { + return inlineMediaOmittedMessage + } + return cleaned + "\n" + inlineMediaOmittedMessage + } + if looksLikeLargeBase64Payload(trimmed) { + return largeBase64OmittedMessage + } + return text +} + +func looksLikeLargeBase64Payload(text string) bool { + trimmed := strings.TrimSpace(text) + if len(trimmed) < 1024 { + return false + } + + nonSpace := 0 + base64Like := 0 + spaceCount := 0 + + for _, r := range trimmed { + if unicode.IsSpace(r) { + spaceCount++ + continue + } + nonSpace++ + if (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '+' || r == '/' || r == '=' { + base64Like++ + } + } + + if nonSpace == 0 { + return false + } + + ratio := float64(base64Like) / float64(nonSpace) + return ratio >= 0.97 && spaceCount <= len(trimmed)/128 +} + +func extractInlineMediaRefs( + text string, + toolName string, + store media.MediaStore, + channel string, + chatID string, + seen map[string]struct{}, +) (cleaned string, refs []string, notes []string) { + cleaned = text + + matches := inlineMarkdownDataURLRe.FindAllStringSubmatch(cleaned, -1) + for _, match := range matches { + if len(match) < 2 { + continue + } + dataURL := match[1] + ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen) + if ref != "" { + refs = append(refs, ref) + } + if note != "" { + notes = append(notes, note) + } + cleaned = strings.ReplaceAll(cleaned, match[0], "") + } + + rawMatches := inlineRawDataURLRe.FindAllString(cleaned, -1) + for _, dataURL := range rawMatches { + ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen) + if ref != "" { + refs = append(refs, ref) + } + if note != "" { + notes = append(notes, note) + } + cleaned = strings.ReplaceAll(cleaned, dataURL, "") + } + + return strings.TrimSpace(cleaned), refs, notes +} + +func storeInlineDataURL( + toolName string, + store media.MediaStore, + channel string, + chatID string, + dataURL string, + seen map[string]struct{}, +) (ref string, note string) { + dataURL = strings.TrimSpace(dataURL) + if _, ok := seen[dataURL]; ok { + return "", "" + } + seen[dataURL] = struct{}{} + + if !strings.HasPrefix(strings.ToLower(dataURL), "data:") { + return "", "" + } + + comma := strings.IndexByte(dataURL, ',') + if comma <= 5 { + return "", "[Tool returned inline media content that could not be parsed.]" + } + + metaPart := dataURL[:comma] + payload := dataURL[comma+1:] + if !strings.Contains(strings.ToLower(metaPart), ";base64") { + return "", "[Tool returned inline media content that was not base64-encoded.]" + } + + mimeType := strings.TrimSpace(strings.TrimPrefix(metaPart, "data:")) + if semi := strings.IndexByte(mimeType, ';'); semi >= 0 { + mimeType = mimeType[:semi] + } + if mimeType == "" { + mimeType = "application/octet-stream" + } + + payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload) + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType) + } + + dir := media.TempDir() + if err = os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + + ext := extensionForMIMEType(mimeType) + tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext) + if err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + tmpPath := tmpFile.Name() + if _, err = tmpFile.Write(decoded); err != nil { + tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + if err = tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + + filename := sanitizeIdentifierComponent(toolName) + ext + scope := fmt.Sprintf( + "tool:inline:%s:%s:%s:%d", + sanitizeIdentifierComponent(toolName), + channel, + chatID, + time.Now().UnixNano(), + ) + + ref, err = store.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: mimeType, + Source: fmt.Sprintf("tool:inline:%s", sanitizeIdentifierComponent(toolName)), + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType) + } + + return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType) +} + +func extensionForMIMEType(mimeType string) string { + if mimeType == "" { + return ".bin" + } + if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 { + return exts[0] + } + + switch strings.ToLower(mimeType) { + case "image/jpeg": + return ".jpg" + case "image/png": + return ".png" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "audio/wav", "audio/x-wav": + return ".wav" + case "audio/mpeg": + return ".mp3" + case "audio/ogg": + return ".ogg" + case "video/mp4": + return ".mp4" + default: + return filepath.Ext(mimeType) + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 0b0f51cc1..902eb4423 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -19,9 +20,14 @@ type ToolEntry struct { } type ToolRegistry struct { - tools map[string]*ToolEntry - mu sync.RWMutex - version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation + tools map[string]*ToolEntry + mu sync.RWMutex + version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation + mediaStore media.MediaStore +} + +type mediaStoreAware interface { + SetMediaStore(store media.MediaStore) } func NewToolRegistry() *ToolRegistry { @@ -43,6 +49,9 @@ func (r *ToolRegistry) Register(tool Tool) { IsCore: true, TTL: 0, // Core tools do not use TTL } + if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil { + aware.SetMediaStore(r.mediaStore) + } r.version.Add(1) logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name}) } @@ -61,10 +70,27 @@ func (r *ToolRegistry) RegisterHidden(tool Tool) { IsCore: false, TTL: 0, } + if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil { + aware.SetMediaStore(r.mediaStore) + } r.version.Add(1) logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name}) } +// SetMediaStore injects a MediaStore into all registered tools that can +// consume it, and remembers it for future registrations. +func (r *ToolRegistry) SetMediaStore(store media.MediaStore) { + r.mu.Lock() + defer r.mu.Unlock() + + r.mediaStore = store + for _, entry := range r.tools { + if aware, ok := entry.Tool.(mediaStoreAware); ok { + aware.SetMediaStore(store) + } + } +} + // PromoteTools atomically sets the TTL for multiple non-core tools. // This prevents a concurrent TickTTL from decrementing between promotions. func (r *ToolRegistry) PromoteTools(names []string, ttl int) { @@ -230,6 +256,8 @@ func (r *ToolRegistry) ExecuteWithContext( } } + result = normalizeToolResult(result, name, r.mediaStore, channel, chatID) + duration := time.Since(start) // Log based on result type @@ -251,7 +279,7 @@ func (r *ToolRegistry) ExecuteWithContext( map[string]any{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result.ForLLM), + "result_length": len(result.ContentForLLM()), }) } @@ -346,7 +374,8 @@ func (r *ToolRegistry) Clone() *ToolRegistry { r.mu.RLock() defer r.mu.RUnlock() clone := &ToolRegistry{ - tools: make(map[string]*ToolEntry, len(r.tools)), + tools: make(map[string]*ToolEntry, len(r.tools)), + mediaStore: r.mediaStore, } for name, entry := range r.tools { clone.tools[name] = &ToolEntry{ diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 967758dfa..db52749f6 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -3,10 +3,13 @@ package tools import ( "context" "errors" + "os" + "path/filepath" "strings" "sync" "testing" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -46,6 +49,15 @@ func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string] return m.result } +type mockMediaStoreAwareTool struct { + mockRegistryTool + store media.MediaStore +} + +func (m *mockMediaStoreAwareTool) SetMediaStore(store media.MediaStore) { + m.store = store +} + // --- helpers --- func newMockTool(name, desc string) *mockRegistryTool { @@ -621,3 +633,102 @@ func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) { t.Errorf("expected 'success', got %q", result2.ForLLM) } } + +func TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + + existing := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("existing", "existing tool"), + } + r.Register(existing) + + r.SetMediaStore(store) + if existing.store != store { + t.Fatal("expected existing tool to receive media store") + } + + later := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("later", "later tool"), + } + r.Register(later) + + if later.store != store { + t.Fatal("expected newly registered tool to inherit media store") + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload(t *testing.T) { + r := NewToolRegistry() + payload := strings.Repeat("QUJD", 400) + r.Register(&mockRegistryTool{ + name: "base64_tool", + desc: "returns huge base64", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "base64_tool", nil, "telegram", "chat-1", nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized payload, got %q", result.ForLLM) + } +} + +func TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + r.SetMediaStore(store) + + payload := "![screenshot](data:image/png;base64,aGVsbG8=)" + r.Register(&mockRegistryTool{ + name: "inline_media_tool", + desc: "returns inline data url", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_tool", nil, "telegram", "chat-42", nil) + + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to be stripped from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "registered as a media attachment") { + t.Fatalf("expected delivery note in ForLLM, got %q", result.ForLLM) + } + + path, err := store.Resolve(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected stored media file to exist: %v", err) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected stored inline media to use png extension, got %q", path) + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore(t *testing.T) { + r := NewToolRegistry() + + payload := "before ![img](data:image/png;base64,aGVsbG8=) after" + r.Register(&mockRegistryTool{ + name: "inline_media_no_store", + desc: "returns inline data url without store", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_no_store", nil, "telegram", "chat-42", nil) + + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to be removed from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, inlineMediaOmittedMessage) { + t.Fatalf("expected inline media omission note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/result.go b/pkg/tools/result.go index cab833284..f75ce6d40 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -1,6 +1,14 @@ package tools -import "encoding/json" +import ( + "encoding/json" + "strings" +) + +const ( + handledToolLLMNote = "The requested output has already been delivered to the user in the current chat. Do not call send_file or any other delivery tool again. If you reply, provide only a brief confirmation." + artifactPathsLLMNote = "Use `send_file` with one of these paths to send it to the user, or use file/exec tools to save it inside the workspace if requested." +) // ToolResult represents the structured return value from tool execution. // It provides clear semantics for different types of results and supports @@ -34,6 +42,48 @@ type ToolResult struct { // Media contains media store refs produced by this tool. // When non-empty, the agent will publish these as OutboundMediaMessage. Media []string `json:"media,omitempty"` + + // ArtifactTags exposes local artifact paths back to the LLM in a structured + // form, e.g. "[file:/tmp/example.png]". This is used when a tool produced a + // reusable local artifact but did not deliver it to the user yet. + ArtifactTags []string `json:"artifact_tags,omitempty"` + + // ResponseHandled indicates that this tool execution already satisfied the + // user's request at the channel/output level, so the agent loop can stop + // without a follow-up assistant response. + ResponseHandled bool `json:"response_handled,omitempty"` +} + +// ContentForLLM returns the normalized textual content to append to the +// conversation after a tool call. Errors fall back to Err when ForLLM is empty. +func (tr *ToolResult) ContentForLLM() string { + if tr == nil { + return "" + } + content := tr.ForLLM + if content == "" && tr.Err != nil { + content = tr.Err.Error() + } + if tr.ResponseHandled { + if content == "" { + return handledToolLLMNote + } + if !strings.Contains(content, handledToolLLMNote) { + content += "\n" + handledToolLLMNote + } + } + if len(tr.ArtifactTags) > 0 { + artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote + if content == "" { + content = artifactNote + } else if !strings.Contains(content, artifactNote) { + content += "\n" + artifactNote + } + } + if content != "" { + return content + } + return "" } // NewToolResult creates a basic ToolResult with content for the LLM. @@ -158,3 +208,9 @@ func (tr *ToolResult) WithError(err error) *ToolResult { tr.Err = err return tr } + +// WithResponseHandled marks the tool result as already delivered to the user. +func (tr *ToolResult) WithResponseHandled() *ToolResult { + tr.ResponseHandled = true + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go index a234e33f3..5f08cb4fa 100644 --- a/pkg/tools/result_test.go +++ b/pkg/tools/result_test.go @@ -3,6 +3,7 @@ package tools import ( "encoding/json" "errors" + "strings" "testing" ) @@ -227,3 +228,41 @@ func TestToolResultJSONStructure(t *testing.T) { t.Errorf("Expected silent false, got %v", parsed["silent"]) } } + +func TestToolResultContentForLLM_AppendsHandledDeliveryNote(t *testing.T) { + result := MediaResult("Screenshot attached.", []string{"media://example"}).WithResponseHandled() + + content := result.ContentForLLM() + if !strings.Contains(content, "Screenshot attached.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, handledToolLLMNote) { + t.Fatalf("expected handled delivery note in ContentForLLM, got %q", content) + } +} + +func TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty(t *testing.T) { + result := (&ToolResult{}).WithResponseHandled() + + if got := result.ContentForLLM(); got != handledToolLLMNote { + t.Fatalf("ContentForLLM() = %q, want %q", got, handledToolLLMNote) + } +} + +func TestToolResultContentForLLM_AppendsArtifactPaths(t *testing.T) { + result := &ToolResult{ + ForLLM: "Artifact created.", + ArtifactTags: []string{"[file:/tmp/example.png]"}, + } + + content := result.ContentForLLM() + if !strings.Contains(content, "Artifact created.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, "Local artifact paths: [file:/tmp/example.png]") { + t.Fatalf("expected artifact path note in ContentForLLM, got %q", content) + } + if !strings.Contains(content, artifactPathsLLMNote) { + t.Fatalf("expected artifact guidance note in ContentForLLM, got %q", content) + } +} diff --git a/pkg/tools/send_file.go b/pkg/tools/send_file.go index a67bd4210..a59ad56aa 100644 --- a/pkg/tools/send_file.go +++ b/pkg/tools/send_file.go @@ -141,7 +141,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult(fmt.Sprintf("failed to register media: %v", err)) } - return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}) + return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled() } // detectMediaType determines the MIME type of a file. diff --git a/pkg/tools/send_file_test.go b/pkg/tools/send_file_test.go index 6daaab31c..cfe5b43e1 100644 --- a/pkg/tools/send_file_test.go +++ b/pkg/tools/send_file_test.go @@ -104,6 +104,9 @@ func TestSendFileTool_Success(t *testing.T) { if result.Media[0][:8] != "media://" { t.Errorf("expected media:// ref, got %q", result.Media[0]) } + if !result.ResponseHandled { + t.Fatal("expected send_file success to mark response handled") + } } func TestSendFileTool_CustomFilename(t *testing.T) { diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 244f0d4a2..387813e94 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -159,10 +159,7 @@ func RunToolLoop( // Append results in original order for _, r := range results { - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() - } + contentForLLM := r.result.ContentForLLM() messages = append(messages, providers.Message{ Role: "tool",