Merge pull request #1889 from afjcjsbx/fix/binary-tool-output-handling

fix(tool): route binary outputs through the media pipeline
This commit is contained in:
Mauro
2026-03-24 15:37:06 +01:00
committed by GitHub
15 changed files with 1963 additions and 144 deletions
+221 -114
View File
@@ -96,14 +96,15 @@ type continuationTarget struct {
}
const (
defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit."
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
sessionKeyAgentPrefix = "agent:"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
metadataKeyParentPeerKind = "parent_peer_kind"
metadataKeyParentPeerID = "parent_peer_id"
defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit."
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
handledToolResponseSummary = "Requested output delivered via tool attachment."
sessionKeyAgentPrefix = "agent:"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
metadataKeyParentPeerKind = "parent_peer_kind"
metadataKeyParentPeerID = "parent_peer_id"
)
func NewAgentLoop(
@@ -1030,13 +1031,13 @@ func (al *AgentLoop) GetConfig() *config.Config {
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
// Propagate store to send_file tools in all agents.
// Propagate store to all registered tools that can emit media.
registry := al.GetRegistry()
registry.ForEachTool("send_file", func(t tools.Tool) {
if sf, ok := t.(*tools.SendFileTool); ok {
sf.SetMediaStore(s)
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
agent.Tools.SetMediaStore(s)
}
})
}
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
@@ -2165,6 +2166,7 @@ turnLoop:
"iteration": iteration,
})
allResponsesHandled := len(normalizedToolCalls) > 0
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
@@ -2221,6 +2223,7 @@ turnLoop:
toolArgs = toolReq.Arguments
}
case HookActionDenyTool:
allResponsesHandled = false
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
al.emitEvent(
EventKindToolExecSkipped,
@@ -2260,6 +2263,7 @@ turnLoop:
ChatID: ts.chatID,
})
if !approval.Approved {
allResponsesHandled = false
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
al.emitEvent(
EventKindToolExecSkipped,
@@ -2333,10 +2337,7 @@ turnLoop:
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
content := result.ContentForLLM()
if content == "" {
return
}
@@ -2420,6 +2421,50 @@ turnLoop:
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
outboundMedia := bus.OutboundMediaMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Parts: parts,
}
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
logger.WarnCF("agent", "Failed to deliver handled tool media",
map[string]any{
"agent_id": ts.agent.ID,
"tool": toolName,
"channel": ts.channel,
"chat_id": ts.chatID,
"error": err.Error(),
})
toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err)
}
} else if al.bus != nil {
al.bus.PublishOutboundMedia(ctx, outboundMedia)
// Queuing media is only best-effort; it has not been delivered yet.
toolResult.ResponseHandled = false
}
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
}
if !toolResult.ResponseHandled {
allResponsesHandled = false
}
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
@@ -2434,30 +2479,7 @@ turnLoop:
})
}
if len(toolResult.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Parts: parts,
})
}
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
contentForLLM := toolResult.ContentForLLM()
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
@@ -2552,6 +2574,70 @@ turnLoop:
}
}
if allResponsesHandled {
if len(pendingMessages) > 0 {
logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(pendingMessages),
"session_key": ts.sessionKey,
})
finalContent = ""
goto turnLoop
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(steerMsgs),
"session_key": ts.sessionKey,
})
pendingMessages = append(pendingMessages, steerMsgs...)
finalContent = ""
goto turnLoop
}
summaryMsg := providers.Message{
Role: "assistant",
Content: handledToolResponseSummary,
}
if !ts.opts.NoHistory {
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
ts.recordPersistedMessage(summaryMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
turnStatus = TurnEndStatusError
al.emitEvent(
EventKindError,
ts.eventMeta("runTurn", "turn.error"),
ErrorPayload{
Stage: "session_save",
Message: err.Error(),
},
)
return turnResult{}, err
}
}
if ts.opts.EnableSummary {
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
}
ts.setPhase(TurnPhaseCompleted)
ts.setFinalContent("")
logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM",
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
"tool_count": len(normalizedToolCalls),
})
return turnResult{
finalContent: "",
status: turnStatus,
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
}, nil
}
ts.agent.Tools.TickTTL()
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
"agent_id": ts.agent.ID, "iteration": iteration,
@@ -3159,6 +3245,97 @@ func (al *AgentLoop) handleCommand(
}
}
// activeSkillNames merges the agent's SkillsFilter with opts.ForcedSkills (in
// that order) into a de-duplicated list. Each name is trimmed, blank names are
// dropped, and when the agent has a ContextBuilder the name is replaced by the
// canonical form it resolves to. Duplicates are detected case-insensitively and
// the first occurrence wins. It returns nil for a nil agent or when no usable
// name remains.
func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
	if agent == nil {
		return nil
	}

	total := len(agent.SkillsFilter) + len(opts.ForcedSkills)
	if total == 0 {
		return nil
	}

	var names []string
	visited := make(map[string]struct{}, total)
	for _, group := range [][]string{agent.SkillsFilter, opts.ForcedSkills} {
		for _, raw := range group {
			candidate := strings.TrimSpace(raw)
			if candidate == "" {
				continue
			}
			if agent.ContextBuilder != nil {
				if canonical, found := agent.ContextBuilder.ResolveSkillName(candidate); found {
					candidate = canonical
				}
			}

			// Lower-cased key so "Foo" and "foo" count as the same skill.
			dedupeKey := strings.ToLower(candidate)
			if _, dup := visited[dedupeKey]; dup {
				continue
			}
			visited[dedupeKey] = struct{}{}
			names = append(names, candidate)
		}
	}
	return names
}
// applyExplicitSkillCommand interprets raw as a "use" command.
//
// Results:
//   - matched reports whether raw is a "use" command at all.
//   - handled reports whether the command was fully answered here; when true,
//     reply carries the text to send back.
//   - matched && !handled means an inline message followed the skill name: the
//     skill is appended to opts.ForcedSkills and opts.UserMessage is replaced
//     by the inline message (both only when opts is non-nil).
//
// "clear" / "off" (case-insensitive) drop the pending skills for
// opts.SessionKey. A skill name with no inline message arms that skill via
// setPendingSkills, which requires a non-nil opts with a non-blank SessionKey.
func (al *AgentLoop) applyExplicitSkillCommand(
	raw string,
	agent *AgentInstance,
	opts *processOptions,
) (matched bool, handled bool, reply string) {
	if name, isCommand := commands.CommandName(raw); !isCommand || name != "use" {
		return false, false, ""
	}
	if agent == nil || agent.ContextBuilder == nil {
		return true, true, commandsUnavailableSkillMessage()
	}

	tokens := strings.Fields(strings.TrimSpace(raw))
	if len(tokens) < 2 {
		return true, true, buildUseCommandHelp(agent)
	}

	target := strings.TrimSpace(tokens[1])
	for _, keyword := range []string{"clear", "off"} {
		if !strings.EqualFold(target, keyword) {
			continue
		}
		if opts != nil {
			al.clearPendingSkills(opts.SessionKey)
		}
		return true, true, "Cleared pending skill override."
	}

	canonical, known := agent.ContextBuilder.ResolveSkillName(target)
	if !known {
		return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", target)
	}

	// Exactly two tokens: command plus skill name, nothing inline.
	if len(tokens) == 2 {
		if opts == nil || strings.TrimSpace(opts.SessionKey) == "" {
			return true, true, commandsUnavailableSkillMessage()
		}
		al.setPendingSkills(opts.SessionKey, []string{canonical})
		return true, true, fmt.Sprintf(
			"Skill %q is armed for your next message. Send your next prompt normally, or use /use clear to cancel.",
			canonical,
		)
	}

	inline := strings.TrimSpace(strings.Join(tokens[2:], " "))
	if inline == "" {
		return true, true, buildUseCommandHelp(agent)
	}

	if opts != nil {
		opts.ForcedSkills = append(opts.ForcedSkills, canonical)
		opts.UserMessage = inline
	}
	return true, false, ""
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
registry := al.GetRegistry()
cfg := al.GetConfig()
@@ -3199,6 +3376,9 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return al.reloadFunc()
}
if agent != nil {
if agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
rt.GetModelInfo = func() (string, string) {
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
@@ -3251,79 +3431,6 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return rt
}
// activeSkillNames returns the agent's SkillsFilter followed by
// opts.ForcedSkills, with each name trimmed, blank names dropped, and exact
// (case-sensitive) duplicates removed. A nil agent contributes no names. The
// result is nil when nothing remains.
func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
	var groups [][]string
	if agent != nil {
		groups = append(groups, agent.SkillsFilter)
	}
	groups = append(groups, opts.ForcedSkills)

	var names []string
	visited := make(map[string]struct{})
	for _, group := range groups {
		for _, raw := range group {
			trimmed := strings.TrimSpace(raw)
			if trimmed == "" {
				continue
			}
			if _, dup := visited[trimmed]; dup {
				continue
			}
			visited[trimmed] = struct{}{}
			names = append(names, trimmed)
		}
	}
	return names
}
// applyExplicitSkillCommand interprets raw as a "use" command.
//
// matched reports whether raw is a "use" command. handled reports whether the
// command was fully answered here, in which case reply holds the text to send
// back. When an inline message follows the skill name, the function returns
// matched=true, handled=false after storing the message in opts.UserMessage
// and appending the skill to opts.ForcedSkills.
//
// opts may be nil: clearing and inline forwarding then skip the opts-dependent
// work, and arming a skill reports that skill selection is unavailable because
// there is no session key to arm it for.
func (al *AgentLoop) applyExplicitSkillCommand(
	raw string,
	agent *AgentInstance,
	opts *processOptions,
) (matched bool, handled bool, reply string) {
	commandName, ok := commands.CommandName(raw)
	if !ok || commandName != "use" {
		return false, false, ""
	}
	if agent == nil || agent.ContextBuilder == nil {
		return true, true, commandsUnavailableSkillMessage()
	}

	fields := strings.Fields(strings.TrimSpace(raw))
	if len(fields) < 2 {
		return true, true, buildUseCommandHelp(agent)
	}

	if strings.EqualFold(fields[1], "clear") || strings.EqualFold(fields[1], "off") {
		// Guard against a nil opts instead of panicking on opts.SessionKey.
		if opts != nil {
			al.clearPendingSkills(opts.SessionKey)
		}
		return true, true, "Cleared pending skill override."
	}

	canonicalSkill, ok := agent.ContextBuilder.ResolveSkillName(fields[1])
	if !ok {
		return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", fields[1])
	}

	if len(fields) == 2 {
		// Without opts there is no session to arm the skill for.
		if opts == nil {
			return true, true, commandsUnavailableSkillMessage()
		}
		al.setPendingSkills(opts.SessionKey, []string{canonicalSkill})
		return true, true, fmt.Sprintf(
			"Skill %q is armed for your next message.\nSend your next request normally, or use /use clear to cancel.",
			canonicalSkill,
		)
	}

	message := strings.TrimSpace(strings.Join(fields[2:], " "))
	if message == "" {
		return true, true, buildUseCommandHelp(agent)
	}

	if opts != nil {
		opts.UserMessage = message
		opts.ForcedSkills = append(opts.ForcedSkills, canonicalSkill)
	}
	return true, false, ""
}
// commandsUnavailableSkillMessage returns the fixed reply used when a skill
// command cannot be honored in the current context.
func commandsUnavailableSkillMessage() string {
	const unavailable = "Skill selection is unavailable in the current context."
	return unavailable
}
+18
View File
@@ -87,6 +87,24 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
return result
}
// buildArtifactTags resolves each media ref through store and returns one path
// tag per ref that resolves, built by buildPathTag from the detected MIME type
// and the resolved local path. Refs that fail to resolve are skipped. It
// returns nil when store is nil or refs is empty.
func buildArtifactTags(store media.MediaStore, refs []string) []string {
	if len(refs) == 0 || store == nil {
		return nil
	}

	out := make([]string, 0, len(refs))
	for i := range refs {
		path, meta, resolveErr := store.ResolveWithMeta(refs[i])
		if resolveErr == nil {
			out = append(out, buildPathTag(detectMIME(path, meta), path))
		}
	}
	return out
}
// detectMIME determines the MIME type from metadata or magic-bytes detection.
// Returns empty string if detection fails.
func detectMIME(localPath string, meta media.MediaMeta) string {
+547
View File
@@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
// fakeMediaChannel extends fakeChannel with a SendMedia method that records
// every media message it is asked to send.
type fakeMediaChannel struct {
	fakeChannel

	// sentMedia holds the messages passed to SendMedia, in call order.
	sentMedia []bus.OutboundMediaMessage
}

// SendMedia records msg and always reports success.
func (f *fakeMediaChannel) SendMedia(_ context.Context, msg bus.OutboundMediaMessage) error {
	recorded := append(f.sentMedia, msg)
	f.sentMedia = recorded
	return nil
}
// newStartedTestChannelManager builds a channels.Manager with an empty config,
// registers ch under name, starts all channels, and registers a cleanup that
// stops them when the test finishes. Any failure aborts the test.
func newStartedTestChannelManager(
	t *testing.T,
	msgBus *bus.MessageBus,
	store media.MediaStore,
	name string,
	ch channels.Channel,
) *channels.Manager {
	t.Helper()

	manager, newErr := channels.NewManager(&config.Config{}, msgBus, store)
	if newErr != nil {
		t.Fatalf("NewManager() error = %v", newErr)
	}
	manager.RegisterChannel(name, ch)

	startErr := manager.StartAll(context.Background())
	if startErr != nil {
		t.Fatalf("StartAll() error = %v", startErr)
	}

	t.Cleanup(func() {
		stopErr := manager.StopAll(context.Background())
		if stopErr != nil {
			t.Fatalf("StopAll() error = %v", stopErr)
		}
	})
	return manager
}
// recordingProvider is a test double that keeps a slice of provider messages.
// NOTE(review): lastMessages presumably holds the messages of the most recent
// Chat call; the methods are outside this view — confirm.
type recordingProvider struct {
lastMessages []providers.Message
}
@@ -289,6 +324,86 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
}
}
// TestApplyExplicitSkillCommand_ArmsSkillForNextMessage checks that
// "/use finance-news" with no inline message is matched and handled
// immediately, replies that the skill is armed, and leaves exactly that skill
// pending for the session.
func TestApplyExplicitSkillCommand_ArmsSkillForNextMessage(t *testing.T) {
	al, cfg, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	// Install a minimal skill in the workspace so the name can be resolved.
	skillDir := filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news")
	if err := os.MkdirAll(skillDir, 0o755); err != nil {
		t.Fatalf("MkdirAll(skill) error = %v", err)
	}
	skillDoc := []byte("# Finance News\n\nUse web tools for current finance updates.\n")
	if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillDoc, 0o644); err != nil {
		t.Fatalf("WriteFile(SKILL.md) error = %v", err)
	}

	agent := al.GetRegistry().GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	opts := &processOptions{SessionKey: "agent:main:test"}
	matched, handled, reply := al.applyExplicitSkillCommand("/use finance-news", agent, opts)
	switch {
	case !matched:
		t.Fatal("expected /use command to match")
	case !handled:
		t.Fatal("expected /use without inline message to be handled immediately")
	case !strings.Contains(reply, `Skill "finance-news" is armed for your next message`):
		t.Fatalf("unexpected reply: %q", reply)
	}

	pending := al.takePendingSkills(opts.SessionKey)
	if len(pending) != 1 || pending[0] != "finance-news" {
		t.Fatalf("pending skills = %#v, want [finance-news]", pending)
	}
}
// TestApplyExplicitSkillCommand_InlineMessageMutatesOptions checks that
// "/use <skill> <message>" is matched but not handled, produces no reply, and
// rewrites opts so that UserMessage is the inline message and ForcedSkills
// contains the requested skill.
func TestApplyExplicitSkillCommand_InlineMessageMutatesOptions(t *testing.T) {
	al, cfg, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	// Install a minimal skill in the workspace so the name can be resolved.
	skillDir := filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news")
	if err := os.MkdirAll(skillDir, 0o755); err != nil {
		t.Fatalf("MkdirAll(skill) error = %v", err)
	}
	skillDoc := []byte("# Finance News\n\nUse web tools for current finance updates.\n")
	if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillDoc, 0o644); err != nil {
		t.Fatalf("WriteFile(SKILL.md) error = %v", err)
	}

	agent := al.GetRegistry().GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	const wantMessage = "dammi le ultime news"
	opts := &processOptions{
		SessionKey:  "agent:main:test",
		UserMessage: "/use finance-news dammi le ultime news",
	}

	matched, handled, reply := al.applyExplicitSkillCommand(opts.UserMessage, agent, opts)
	switch {
	case !matched:
		t.Fatal("expected /use command to match")
	case handled:
		t.Fatal("expected /use with inline message to fall through into normal agent execution")
	case reply != "":
		t.Fatalf("unexpected reply: %q", reply)
	}

	if opts.UserMessage != wantMessage {
		t.Fatalf("opts.UserMessage = %q, want %q", opts.UserMessage, wantMessage)
	}
	if len(opts.ForcedSkills) != 1 || opts.ForcedSkills[0] != "finance-news" {
		t.Fatalf("opts.ForcedSkills = %#v, want [finance-news]", opts.ForcedSkills)
	}
}
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
@@ -455,6 +570,217 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
}
}
// TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText runs one
// message through processMessage with a tool whose media result is marked as
// response-handled. It asserts that no final text is returned, the provider is
// called exactly once (with tools available), the media is sent once through
// the channel rather than queued on the bus, and the session history ends with
// the assistant summary "Requested output delivered via tool attachment.".
func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &handledMediaProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
store := media.NewFileMediaStore()
al.SetMediaStore(store)
// Register a started channel manager with a fake "telegram" channel that
// records every SendMedia call.
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
// The tool stores this file in the media store when executed.
imagePath := filepath.Join(tmpDir, "screen.png")
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
t.Fatalf("WriteFile(imagePath) error = %v", err)
}
al.RegisterTool(&handledMediaTool{
store: store,
path: imagePath,
})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
// Handled delivery must not produce a final text response.
if response != "" {
t.Fatalf("expected no final response when media tool already handled delivery, got %q", response)
}
// No follow-up LLM call may happen after the handled tool result.
if provider.calls != 1 {
t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls)
}
if len(provider.toolCounts) != 1 {
t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts))
}
if provider.toolCounts[0] == 0 {
t.Fatal("expected tools to be available on the first LLM call")
}
// The media must have reached the channel exactly once, for the right chat.
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
}
if len(telegramChannel.sentMedia[0].Parts) != 1 {
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
}
// Non-blocking receive: nothing may have been queued on the bus as well.
select {
case extra := <-msgBus.OutboundMediaChan():
t.Fatalf("expected handled media to bypass async queue, got %+v", extra)
default:
}
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Re-resolve the route for the same inbound message to derive the session
// key under which the history is looked up.
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
sessionKey := resolveScopeKey(route, "")
history := defaultAgent.Sessions.GetHistory(sessionKey)
if len(history) == 0 {
t.Fatal("expected session history to be saved")
}
// The last persisted message must be the handled-delivery summary.
last := history[len(history)-1]
if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." {
t.Fatalf("expected handled assistant summary in history, got %+v", last)
}
}
// TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning uses a
// tool that enqueues a steering message while returning response-handled
// media. It asserts that processMessage returns the provider's reply to the
// steering message, that the provider was called twice, and that the media was
// sent through the channel exactly once.
func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &handledMediaWithSteeringProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
store := media.NewFileMediaStore()
al.SetMediaStore(store)
// Fake "telegram" channel that records every SendMedia call.
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
imagePath := filepath.Join(tmpDir, "screen-steering.png")
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
t.Fatalf("WriteFile(imagePath) error = %v", err)
}
// The tool receives the loop so it can call Steer during execution.
al.RegisterTool(&handledMediaWithSteeringTool{
store: store,
path: imagePath,
loop: al,
})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
// The turn must continue past the handled tool and answer the steering message.
if response != "Handled the queued steering message." {
t.Fatalf("response = %q, want queued steering response", response)
}
if provider.calls != 2 {
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
}
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
}
// TestProcessMessage_MediaArtifactCanBeForwardedBySendFile runs a two-step
// turn: the first tool returns a media artifact that is not marked as
// response-handled, and the provider then calls send_file with the artifact
// path it finds in the tool result. The test asserts that no final text is
// returned, the provider is called twice, and the media reaches the channel
// exactly once without also being queued on the bus.
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
	tmpDir := t.TempDir()
	cfg := config.DefaultConfig()
	cfg.Agents.Defaults.Workspace = tmpDir
	cfg.Agents.Defaults.ModelName = "test-model"
	cfg.Agents.Defaults.MaxTokens = 4096
	cfg.Agents.Defaults.MaxToolIterations = 10

	msgBus := bus.NewMessageBus()
	provider := &artifactThenSendProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)
	store := media.NewFileMediaStore()
	al.SetMediaStore(store)

	// Fake "telegram" channel that records every SendMedia call.
	telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
	al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))

	mediaDir := media.TempDir()
	if err := os.MkdirAll(mediaDir, 0o700); err != nil {
		t.Fatalf("MkdirAll(mediaDir) error = %v", err)
	}
	imagePath := filepath.Join(mediaDir, "artifact-screen.png")
	if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
		t.Fatalf("WriteFile(imagePath) error = %v", err)
	}
	// media.TempDir() is not owned by this test (unlike t.TempDir()), so remove
	// the fixture explicitly. The error is ignored on purpose: the file may
	// already be gone by the time cleanup runs.
	t.Cleanup(func() { _ = os.Remove(imagePath) })

	al.RegisterTool(&mediaArtifactTool{
		store: store,
		path:  imagePath,
	})

	response, err := al.processMessage(context.Background(), bus.InboundMessage{
		Channel:  "telegram",
		ChatID:   "chat1",
		SenderID: "user1",
		Content:  "take a screenshot of the screen and send it to me",
	})
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}
	if response != "" {
		t.Fatalf("expected no final response after send_file handled delivery, got %q", response)
	}
	if provider.calls != 2 {
		t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
	}

	if len(telegramChannel.sentMedia) != 1 {
		t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
	}
	if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
		t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
	}
	if len(telegramChannel.sentMedia[0].Parts) != 1 {
		t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
	}

	// Non-blocking receive: nothing may have been queued on the bus as well.
	select {
	case extra := <-msgBus.OutboundMediaChan():
		t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra)
	default:
	}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
@@ -600,6 +926,98 @@ func (m *countingMockProvider) GetDefaultModel() string {
return "counting-mock-model"
}
// handledMediaProvider is a scripted provider for tests: its first Chat call
// requests the "handled_media_tool" tool and every later call returns an empty
// response.
type handledMediaProvider struct {
	// calls counts Chat invocations.
	calls int
	// toolCounts records len(tools) for each Chat invocation, in order.
	toolCounts []int
}

// Chat records the call and the number of tool definitions offered, then
// replies according to the call number (see the type comment).
func (m *handledMediaProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	m.calls++
	m.toolCounts = append(m.toolCounts, len(tools))

	if m.calls != 1 {
		return &providers.LLMResponse{}, nil
	}

	toolCall := providers.ToolCall{
		ID:        "call_handled_media",
		Type:      "function",
		Name:      "handled_media_tool",
		Arguments: map[string]any{},
	}
	return &providers.LLMResponse{
		Content:   "Taking the screenshot now.",
		ToolCalls: []providers.ToolCall{toolCall},
	}, nil
}

// GetDefaultModel returns the fixed model name of this test provider.
func (m *handledMediaProvider) GetDefaultModel() string {
	return "handled-media-model"
}
// artifactThenSendProvider is a scripted provider for tests: its first Chat
// call requests "media_artifact_tool"; later calls look for a "[file:<path>]"
// tag in the most recent tool message that has one and request "send_file"
// with that path.
type artifactThenSendProvider struct {
	// calls counts Chat invocations.
	calls int
}

// Chat replies according to the call number (see the type comment). On later
// calls it returns an error when no tool message yields a non-empty artifact
// path.
func (m *artifactThenSendProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	m.calls++
	if m.calls == 1 {
		first := providers.ToolCall{
			ID:        "call_artifact_media",
			Type:      "function",
			Name:      "media_artifact_tool",
			Arguments: map[string]any{},
		}
		return &providers.LLMResponse{
			Content:   "Taking the screenshot now.",
			ToolCalls: []providers.ToolCall{first},
		}, nil
	}

	// Scan tool messages newest-first and stop at the first complete tag.
	var artifactPath string
	for idx := len(messages) - 1; idx >= 0; idx-- {
		msg := messages[idx]
		if msg.Role != "tool" {
			continue
		}
		_, afterOpen, hasOpen := strings.Cut(msg.Content, "[file:")
		if !hasOpen {
			continue
		}
		path, _, hasClose := strings.Cut(afterOpen, "]")
		if !hasClose {
			continue
		}
		artifactPath = path
		break
	}
	if artifactPath == "" {
		return nil, fmt.Errorf("provider did not receive artifact path in tool result")
	}

	forward := providers.ToolCall{
		ID:        "call_send_file",
		Type:      "function",
		Name:      "send_file",
		Arguments: map[string]any{"path": artifactPath},
	}
	return &providers.LLMResponse{
		Content:   "",
		ToolCalls: []providers.ToolCall{forward},
	}, nil
}

// GetDefaultModel returns the fixed model name of this test provider.
func (m *artifactThenSendProvider) GetDefaultModel() string {
	return "artifact-then-send-model"
}
type toolLimitOnlyProvider struct{}
func (m *toolLimitOnlyProvider) Chat(
@@ -646,6 +1064,135 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
return tools.SilentResult("Custom tool executed")
}
// handledMediaTool is a test tool that stores the file at path in the media
// store and returns it as a media result marked as response-handled.
type handledMediaTool struct {
	store media.MediaStore
	path  string
}

// Name returns the tool name used in tool calls.
func (m *handledMediaTool) Name() string { return "handled_media_tool" }

// Description returns the human-readable tool description.
func (m *handledMediaTool) Description() string {
	return "Returns a media attachment and fully handles the user response"
}

// Parameters returns an object schema with no properties.
func (m *handledMediaTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute stores the configured file and returns the resulting ref as handled
// media; a store failure is returned as an error result.
func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:handled_media_tool",
	}
	ref, storeErr := m.store.Store(m.path, meta, "test:handled_media")
	if storeErr != nil {
		return tools.ErrorResult(storeErr.Error()).WithError(storeErr)
	}
	return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
}
// handledMediaWithSteeringProvider is a scripted provider for tests: its first
// Chat call requests "handled_media_with_steering_tool"; later calls succeed
// only if the conversation contains the user message
// "what about this instead?".
type handledMediaWithSteeringProvider struct {
	// calls counts Chat invocations.
	calls int
}

// Chat replies according to the call number (see the type comment) and returns
// an error on later calls when the steering message is absent.
func (m *handledMediaWithSteeringProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	m.calls++
	if m.calls == 1 {
		toolCall := providers.ToolCall{
			ID:        "call_handled_media_steering",
			Type:      "function",
			Name:      "handled_media_with_steering_tool",
			Arguments: map[string]any{},
		}
		return &providers.LLMResponse{
			Content:   "Taking the screenshot now.",
			ToolCalls: []providers.ToolCall{toolCall},
		}, nil
	}

	for i := range messages {
		if messages[i].Role != "user" || messages[i].Content != "what about this instead?" {
			continue
		}
		return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil
	}
	return nil, fmt.Errorf("provider did not receive queued steering message")
}

// GetDefaultModel returns the fixed model name of this test provider.
func (m *handledMediaWithSteeringProvider) GetDefaultModel() string {
	return "handled-media-with-steering-model"
}
// handledMediaWithSteeringTool is a test tool that first enqueues a steering
// message on the agent loop and then returns the file at path as a media
// result marked as response-handled.
type handledMediaWithSteeringTool struct {
	store media.MediaStore
	path  string
	loop  *AgentLoop
}

// Name returns the tool name used in tool calls.
func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" }

// Description returns the human-readable tool description.
func (m *handledMediaWithSteeringTool) Description() string {
	return "Returns handled media and enqueues a steering message during execution"
}

// Parameters returns an object schema with no properties.
func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute steers the loop with a user message, then stores the configured file
// and returns its ref as handled media. A failure in either step is returned
// as an error result.
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	steering := providers.Message{Role: "user", Content: "what about this instead?"}
	if steerErr := m.loop.Steer(steering); steerErr != nil {
		return tools.ErrorResult(steerErr.Error()).WithError(steerErr)
	}

	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:handled_media_with_steering_tool",
	}
	ref, storeErr := m.store.Store(m.path, meta, "test:handled_media_with_steering")
	if storeErr != nil {
		return tools.ErrorResult(storeErr.Error()).WithError(storeErr)
	}
	return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
}
// mediaArtifactTool is a test tool that stores the file at path in the media
// store and returns it as a plain media result (not marked as
// response-handled).
type mediaArtifactTool struct {
	store media.MediaStore
	path  string
}

// Name returns the tool name used in tool calls.
func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" }

// Description returns the human-readable tool description.
func (m *mediaArtifactTool) Description() string {
	return "Returns a media artifact that the agent can forward or save later"
}

// Parameters returns an object schema with no properties.
func (m *mediaArtifactTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute stores the configured file and returns the resulting ref as a media
// result; a store failure is returned as an error result.
func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:media_artifact_tool",
	}
	ref, storeErr := m.store.Store(m.path, meta, "test:media_artifact")
	if storeErr != nil {
		return tools.ErrorResult(storeErr.Error()).WithError(storeErr)
	}
	return tools.MediaResult("Artifact created.", []string{ref})
}
type toolLimitTestTool struct{}
func (m *toolLimitTestTool) Name() string {
+75 -9
View File
@@ -206,6 +206,40 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
return false
}
// preSendMedia runs the per-chat cleanup that precedes a media send: it stops
// the typing indicator, undoes the reaction, drops the stream-active marker,
// and removes the placeholder entry for the chat. Media has no text payload to
// edit a placeholder with, so the placeholder is only ever deleted, and only
// when it has an id and the channel implements MessageDeleter.
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
	chatKey := name + ":" + msg.ChatID

	// Stop typing (the original notes stop() as idempotent).
	if stored, found := m.typingStops.LoadAndDelete(chatKey); found {
		if typing, isTyping := stored.(typingEntry); isTyping {
			typing.stop()
		}
	}

	// Undo the reaction (the original notes undo() as idempotent).
	if stored, found := m.reactionUndos.LoadAndDelete(chatKey); found {
		if reaction, isReaction := stored.(reactionEntry); isReaction {
			reaction.undo()
		}
	}

	// Forget any finalized stream marker for this chat.
	m.streamActive.LoadAndDelete(chatKey)

	// Remove the placeholder entry and, when possible, the message itself.
	stored, found := m.placeholders.LoadAndDelete(chatKey)
	if !found {
		return
	}
	placeholder, isPlaceholder := stored.(placeholderEntry)
	if !isPlaceholder || placeholder.id == "" {
		return
	}
	if deleter, canDelete := ch.(MessageDeleter); canDelete {
		deleter.DeleteMessage(ctx, msg.ChatID, placeholder.id) // best effort
	}
}
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
m := &Manager{
channels: make(map[string]Channel),
@@ -774,7 +808,7 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor
if !ok {
return
}
m.sendMediaWithRetry(ctx, name, w, msg)
_ = m.sendMediaWithRetry(ctx, name, w, msg)
case <-ctx.Done():
return
}
@@ -782,26 +816,37 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor
}
// sendMediaWithRetry sends a media message through the channel with rate limiting and
// retry logic. If the channel does not implement MediaSender, it silently skips.
func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) {
// retry logic. It returns nil on success, or the last error after retries,
// including when the channel does not support MediaSender.
func (m *Manager) sendMediaWithRetry(
ctx context.Context,
name string,
w *channelWorker,
msg bus.OutboundMediaMessage,
) error {
ms, ok := w.ch.(MediaSender)
if !ok {
logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{
err := fmt.Errorf("channel %q does not support media sending", name)
logger.WarnCF("channels", "Channel does not support MediaSender", map[string]any{
"channel": name,
"error": err.Error(),
})
return
return err
}
// Rate limit: wait for token
if err := w.limiter.Wait(ctx); err != nil {
return
return err
}
// Pre-send: stop typing and clean up any placeholder before sending media.
m.preSendMedia(ctx, name, msg, w.ch)
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
lastErr = ms.SendMedia(ctx, msg)
if lastErr == nil {
return
return nil
}
// Permanent failures — don't retry
@@ -820,7 +865,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
case <-time.After(rateLimitDelay):
continue
case <-ctx.Done():
return
return ctx.Err()
}
}
@@ -829,7 +874,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
select {
case <-time.After(backoff):
case <-ctx.Done():
return
return ctx.Err()
}
}
@@ -840,6 +885,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
"error": lastErr.Error(),
"retries": maxRetries,
})
return lastErr
}
// runTTLJanitor periodically scans the typingStops and placeholders maps
@@ -1032,6 +1078,26 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
return nil
}
// SendMedia delivers an outbound media message synchronously, going through
// the target channel worker's rate limiter and retry logic. The call blocks
// until delivery succeeds or sendMediaWithRetry gives up, which keeps ordering
// intact when later agent behavior depends on the media actually being sent.
//
// It returns an error when the channel is unknown, when the channel has no
// active worker, or when sendMediaWithRetry reports a failure.
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
	channelName := msg.Channel

	// Look up both maps under one read lock so the snapshot is consistent.
	m.mu.RLock()
	_, known := m.channels[channelName]
	worker, hasWorker := m.workers[channelName]
	m.mu.RUnlock()

	switch {
	case !known:
		return fmt.Errorf("channel %s not found", msg.Channel)
	case !hasWorker || worker == nil:
		return fmt.Errorf("channel %s has no active worker", msg.Channel)
	}

	return m.sendMediaWithRetry(ctx, channelName, worker, msg)
}
func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error {
m.mu.RLock()
_, exists := m.channels[channelName]
+154
View File
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
@@ -43,6 +44,40 @@ func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, conten
return nil
}
// mockMediaChannel extends mockChannel with a SendMedia method so tests can
// exercise the media delivery path.
type mockMediaChannel struct {
	mockChannel
	// sendMediaFn, when set, decides the result of SendMedia.
	sendMediaFn func(ctx context.Context, msg bus.OutboundMediaMessage) error
	// sentMediaMessages records every message passed to SendMedia, in order.
	sentMediaMessages []bus.OutboundMediaMessage
}

// SendMedia records msg and then defers to sendMediaFn; without a hook it
// reports success.
func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
	m.sentMediaMessages = append(m.sentMediaMessages, msg)
	if m.sendMediaFn == nil {
		return nil
	}
	return m.sendMediaFn(ctx, msg)
}
// mockDeletingMediaChannel is a mockMediaChannel that also exposes
// DeleteMessage, letting tests observe placeholder deletion.
type mockDeletingMediaChannel struct {
	mockMediaChannel
	// deleteCalls counts DeleteMessage invocations.
	deleteCalls int
	// lastDeleted holds the arguments of the most recent DeleteMessage call.
	lastDeleted struct {
		chatID    string
		messageID string
	}
}

// DeleteMessage records the deletion target and always succeeds.
func (m *mockDeletingMediaChannel) DeleteMessage(
	_ context.Context,
	chatID string,
	messageID string,
) error {
	m.deleteCalls++
	m.lastDeleted.chatID, m.lastDeleted.messageID = chatID, messageID
	return nil
}
// newTestManager creates a minimal Manager suitable for unit tests.
func newTestManager() *Manager {
return &Manager{
@@ -208,6 +243,125 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
}
}
// TestSendMedia_Success checks that a media message sent through a
// media-capable channel reaches the channel exactly once and reports no error.
func TestSendMedia_Success(t *testing.T) {
	mgr := newTestManager()

	sendCalls := 0
	channel := &mockMediaChannel{
		sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
			sendCalls++
			return nil
		},
	}
	mgr.channels["test"] = channel
	mgr.workers["test"] = &channelWorker{
		ch:      channel,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}

	outbound := bus.OutboundMediaMessage{
		Channel: "test",
		ChatID:  "chat1",
		Parts:   []bus.MediaPart{{Ref: "media://abc"}},
	}
	if err := mgr.SendMedia(context.Background(), outbound); err != nil {
		t.Fatalf("SendMedia() error = %v", err)
	}
	if sendCalls != 1 {
		t.Fatalf("expected 1 SendMedia call, got %d", sendCalls)
	}
}
// TestSendMedia_PropagatesFailure checks that an error returned by the
// channel's SendMedia surfaces from Manager.SendMedia and still matches
// ErrSendFailed through errors.Is.
func TestSendMedia_PropagatesFailure(t *testing.T) {
	mgr := newTestManager()

	failing := &mockMediaChannel{
		sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
			return fmt.Errorf("bad upload: %w", ErrSendFailed)
		},
	}
	mgr.channels["test"] = failing
	mgr.workers["test"] = &channelWorker{
		ch:      failing,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}

	outbound := bus.OutboundMediaMessage{
		Channel: "test",
		ChatID:  "chat1",
		Parts:   []bus.MediaPart{{Ref: "media://abc"}},
	}
	err := mgr.SendMedia(context.Background(), outbound)
	switch {
	case err == nil:
		t.Fatal("expected SendMedia to return error")
	case !errors.Is(err, ErrSendFailed):
		t.Fatalf("expected ErrSendFailed, got %v", err)
	}
}
// TestSendMedia_UnsupportedChannelReturnsError checks that sending media
// through a channel without a SendMedia method yields an explicit error
// instead of being silently skipped.
func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
	mgr := newTestManager()

	// mockChannel only handles text sends; it has no SendMedia method.
	textOnly := &mockChannel{
		sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
			return nil
		},
	}
	mgr.channels["test"] = textOnly
	mgr.workers["test"] = &channelWorker{
		ch:      textOnly,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}

	outbound := bus.OutboundMediaMessage{
		Channel: "test",
		ChatID:  "chat1",
		Parts:   []bus.MediaPart{{Ref: "media://abc"}},
	}
	err := mgr.SendMedia(context.Background(), outbound)
	switch {
	case err == nil:
		t.Fatal("expected SendMedia to return error for unsupported channel")
	case !strings.Contains(err.Error(), "does not support media sending"):
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestSendMedia_DeletesPlaceholderBeforeSending checks that a recorded
// placeholder is deleted through the channel's DeleteMessage and that the
// media message itself is still delivered once.
func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
	mgr := newTestManager()

	channel := &mockDeletingMediaChannel{
		mockMediaChannel: mockMediaChannel{
			sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
				return nil
			},
		},
	}
	mgr.channels["test"] = channel
	mgr.workers["test"] = &channelWorker{
		ch:      channel,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}
	mgr.RecordPlaceholder("test", "chat1", "placeholder-1")

	outbound := bus.OutboundMediaMessage{
		Channel: "test",
		ChatID:  "chat1",
		Parts:   []bus.MediaPart{{Ref: "media://abc"}},
	}
	if err := mgr.SendMedia(context.Background(), outbound); err != nil {
		t.Fatalf("SendMedia() error = %v", err)
	}

	if channel.deleteCalls != 1 {
		t.Fatalf("expected placeholder delete to be called once, got %d", channel.deleteCalls)
	}
	deleted := channel.lastDeleted
	if deleted.chatID != "chat1" || deleted.messageID != "placeholder-1" {
		t.Fatalf("unexpected placeholder deletion target: %+v", channel.lastDeleted)
	}
	if sent := len(channel.sentMediaMessages); sent != 1 {
		t.Fatalf("expected media to be sent once, got %d", sent)
	}
}
func TestSendWithRetry_UnknownError(t *testing.T) {
m := newTestManager()
var callCount int
+269 -11
View File
@@ -5,9 +5,13 @@ import (
"encoding/json"
"fmt"
"hash/fnv"
"os"
"strings"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/media"
)
// MCPManager defines the interface for MCP manager operations
@@ -25,6 +29,7 @@ type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
mediaStore media.MediaStore
}
// NewMCPTool creates a new MCP tool wrapper
@@ -36,6 +41,10 @@ func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool
}
}
// SetMediaStore sets the media store used to register binary MCP content as
// media. While no store is set, storeBinaryContent omits binary content and
// returns only an explanatory note.
func (t *MCPTool) SetMediaStore(store media.MediaStore) {
	t.mediaStore = store
}
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
@@ -218,13 +227,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
// Extract text content from result
output := extractContentText(result.Content)
return &ToolResult{
ForLLM: output,
IsError: false,
}
return t.normalizeResultContent(ctx, result.Content)
}
// extractContentText extracts text from MCP content array
@@ -233,14 +236,269 @@ func extractContentText(content []mcp.Content) string {
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, v.Text)
parts = append(parts, sanitizeToolLLMContent(v.Text))
case *mcp.ImageContent:
// For images, just indicate that an image was returned
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.AudioContent:
parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.ResourceLink:
parts = append(parts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
parts = append(parts, summarizeEmbeddedResource(v))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return strings.Join(parts, "\n")
return sanitizeToolLLMContent(strings.Join(parts, "\n"))
}
// normalizeResultContent converts MCP content items into a ToolResult.
//
// Text is sanitized and kept for the model. Image, audio, and embedded
// resource payloads are handed to the media store and replaced by a short
// note. Resource links and unsupported content types are summarized as text.
// The returned ForLLM joins all non-blank notes with newlines, and Media
// lists the stored media refs in content order.
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
	notes := make([]string, 0, len(content))
	refs := make([]string, 0, len(content))

	// record keeps whichever halves of a (ref, note) pair are non-empty.
	record := func(ref, note string) {
		if ref != "" {
			refs = append(refs, ref)
		}
		if note != "" {
			notes = append(notes, note)
		}
	}

	for _, item := range content {
		switch c := item.(type) {
		case *mcp.TextContent:
			record("", strings.TrimSpace(sanitizeToolLLMContent(c.Text)))
		case *mcp.ImageContent:
			record(t.storeBinaryContent(ctx, "image", normalizedMIMEType(c.MIMEType), c.Data, c.Annotations))
		case *mcp.AudioContent:
			record(t.storeBinaryContent(ctx, "audio", normalizedMIMEType(c.MIMEType), c.Data, c.Annotations))
		case *mcp.ResourceLink:
			notes = append(notes, summarizeResourceLink(c))
		case *mcp.EmbeddedResource:
			record(t.storeEmbeddedResource(ctx, c))
		default:
			notes = append(notes, fmt.Sprintf("[MCP returned unsupported content type %T]", c))
		}
	}

	return &ToolResult{
		ForLLM: strings.Join(compactStrings(notes), "\n"),
		Media:  refs,
	}
}
// storeEmbeddedResource handles an MCP embedded resource and returns a
// (media ref, note for the model) pair.
//
// A resource with a blob goes through storeBinaryContent. A resource with
// non-blank text returns that text sanitized. Anything else returns a short
// summary of the resource. The ref is empty unless a blob was stored.
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
	if content == nil || content.Resource == nil {
		return "", "[MCP returned an embedded resource without data.]"
	}

	res := content.Resource
	switch {
	case len(res.Blob) > 0:
		return t.storeBinaryContent(
			ctx,
			"resource",
			normalizedMIMEType(res.MIMEType),
			res.Blob,
			content.Annotations,
		)
	case strings.TrimSpace(res.Text) != "":
		return "", sanitizeToolLLMContent(res.Text)
	default:
		return "", summarizeEmbeddedResource(content)
	}
}
// storeBinaryContent writes binary MCP content to a temp file, registers it
// with the media store, and returns (media ref, note for the model).
//
// The ref is empty whenever the content is not stored: empty data, an
// audience that excludes the user, no media store, no target channel/chat in
// ctx, or a filesystem/store failure. In every case the note describes what
// happened, so the raw bytes never reach the model context.
func (t *MCPTool) storeBinaryContent(
	ctx context.Context,
	kind string,
	mimeType string,
	data []byte,
	annotations *mcp.Annotations,
) (string, string) {
	switch {
	case len(data) == 0:
		return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType)
	case !annotationsAllowUser(annotations):
		return "", fmt.Sprintf(
			"[MCP returned %s content (%s) for non-user audience; omitted from model context.]",
			kind,
			mimeType,
		)
	case t.mediaStore == nil:
		return "", fmt.Sprintf(
			"[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]",
			kind,
			mimeType,
		)
	}

	channel, chatID := ToolChannel(ctx), ToolChatID(ctx)
	if channel == "" || chatID == "" {
		return "", fmt.Sprintf(
			"[MCP returned %s content (%s); omitted from model context because no target chat was available.]",
			kind,
			mimeType,
		)
	}

	// Every filesystem failure below reports the same note.
	notStored := fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)

	dir := media.TempDir()
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return "", notStored
	}

	ext := extensionForMIMEType(mimeType)
	tmp, err := os.CreateTemp(dir, "mcp-*"+ext)
	if err != nil {
		return "", notStored
	}
	tmpPath := tmp.Name()

	// Always close the file; a failed write or close discards the temp file.
	_, writeErr := tmp.Write(data)
	closeErr := tmp.Close()
	if writeErr != nil || closeErr != nil {
		_ = os.Remove(tmpPath)
		return "", notStored
	}

	server := sanitizeIdentifierComponent(t.serverName)
	toolName := sanitizeIdentifierComponent(t.tool.Name)

	// The nanosecond timestamp makes the scope distinct per stored artifact.
	scope := fmt.Sprintf("tool:mcp:%s:%s:%s:%d", server, channel, chatID, time.Now().UnixNano())
	meta := media.MediaMeta{
		Filename:    fmt.Sprintf("%s_%s%s", server, toolName, ext),
		ContentType: mimeType,
		Source:      fmt.Sprintf("tool:mcp:%s:%s", server, toolName),
	}

	ref, err := t.mediaStore.Store(tmpPath, meta, scope)
	if err != nil {
		_ = os.Remove(tmpPath)
		return "", fmt.Sprintf(
			"[MCP returned %s content (%s) but it could not be registered as media.]",
			kind,
			mimeType,
		)
	}

	return ref, fmt.Sprintf(
		"[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]",
		kind,
		mimeType,
	)
}
// summarizeResourceLink renders an MCP resource link as a single bracketed
// line for the model, listing whichever of name, URI, MIME type, and
// description are set. A nil link yields a fixed placeholder message.
//
// The description is whitespace-trimmed and capped at 200 bytes. The cut is
// moved back to the nearest rune boundary so a multi-byte UTF-8 character is
// never split, which would leave invalid bytes that %q prints as \x escapes.
func summarizeResourceLink(content *mcp.ResourceLink) string {
	if content == nil {
		return "[MCP returned an empty resource link.]"
	}

	const maxDescriptionBytes = 200

	parts := []string{"[MCP returned resource link"}
	if content.Name != "" {
		parts = append(parts, fmt.Sprintf("name=%q", content.Name))
	}
	if content.URI != "" {
		parts = append(parts, fmt.Sprintf("uri=%q", content.URI))
	}
	if content.MIMEType != "" {
		parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType))
	}
	if content.Description != "" {
		desc := strings.TrimSpace(content.Description)
		if len(desc) > maxDescriptionBytes {
			cut := maxDescriptionBytes
			// UTF-8 continuation bytes have the form 10xxxxxx. While the byte
			// at the cut is one, the cut sits inside a rune; step back to the
			// rune's first byte so the partial rune is dropped entirely.
			for cut > 0 && desc[cut]&0xC0 == 0x80 {
				cut--
			}
			desc = desc[:cut] + "..."
		}
		parts = append(parts, fmt.Sprintf("description=%q", desc))
	}
	return strings.Join(parts, ", ") + "]"
}
// summarizeEmbeddedResource describes an embedded MCP resource in one
// bracketed line: its MIME type (normalized) and, when present, its quoted
// URI. A nil resource yields a generic placeholder.
func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
	if content == nil || content.Resource == nil {
		return "[MCP returned an embedded resource.]"
	}

	res := content.Resource
	mimeType := normalizedMIMEType(res.MIMEType)
	if res.URI == "" {
		return fmt.Sprintf("[MCP returned embedded resource (%s).]", mimeType)
	}
	return fmt.Sprintf("[MCP returned embedded resource %q (%s).]", res.URI, mimeType)
}
// annotationsAllowUser reports whether content carrying these annotations may
// be shown to the user. Missing annotations or an empty audience list allow
// it; otherwise the audience must contain "user" (compared case-insensitively).
func annotationsAllowUser(annotations *mcp.Annotations) bool {
	if annotations == nil {
		return true
	}

	roles := annotations.Audience
	if len(roles) == 0 {
		return true
	}
	for i := range roles {
		if strings.EqualFold(string(roles[i]), "user") {
			return true
		}
	}
	return false
}
// normalizedMIMEType returns mimeType unchanged unless it is empty or only
// whitespace, in which case it returns "application/octet-stream".
func normalizedMIMEType(mimeType string) string {
	if strings.TrimSpace(mimeType) != "" {
		return mimeType
	}
	return "application/octet-stream"
}
// compactStrings returns a new slice holding the elements of parts that are
// not empty or whitespace-only, in their original order. Kept elements are
// not trimmed. The result is never nil.
func compactStrings(parts []string) []string {
	kept := make([]string, 0, len(parts))
	for i := range parts {
		if blank := strings.TrimSpace(parts[i]) == ""; !blank {
			kept = append(kept, parts[i])
		}
	}
	return kept
}
+144
View File
@@ -3,10 +3,14 @@ package tools
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/media"
)
// MockMCPManager is a mock implementation of MCPManager interface for testing
@@ -490,3 +494,143 @@ func TestMCPTool_Parameters_MapSchema(t *testing.T) {
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
}
}
// TestMCPTool_Execute_ImageContentStoredAsMedia checks that image content
// returned by an MCP tool is written to the media store (correct bytes,
// content type, and .png extension) and that the model only sees a note.
func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) {
	const imageBytes = "fake-image-bytes"

	mediaStore := media.NewFileMediaStore()
	mock := &MockMCPManager{
		callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
			image := &mcp.ImageContent{
				Data:     []byte(imageBytes),
				MIMEType: "image/png",
			}
			return &mcp.CallToolResult{Content: []mcp.Content{image}}, nil
		},
	}

	tool := NewMCPTool(mock, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
	tool.SetMediaStore(mediaStore)

	toolCtx := WithToolContext(context.Background(), "telegram", "chat-42")
	got := tool.Execute(toolCtx, nil)

	if got.IsError {
		t.Fatalf("expected success, got %q", got.ForLLM)
	}
	if n := len(got.Media); n != 1 {
		t.Fatalf("expected 1 media ref, got %d", n)
	}
	if got.ResponseHandled {
		t.Fatal("expected MCP image artifact not to mark response as handled")
	}
	if !strings.Contains(got.ForLLM, "stored as a local media artifact") {
		t.Fatalf("expected local media artifact note, got %q", got.ForLLM)
	}

	storedPath, meta, err := mediaStore.ResolveWithMeta(got.Media[0])
	if err != nil {
		t.Fatalf("expected stored media ref to resolve: %v", err)
	}
	if meta.ContentType != "image/png" {
		t.Fatalf("expected image/png content type, got %q", meta.ContentType)
	}
	if filepath.Ext(storedPath) != ".png" {
		t.Fatalf("expected png temp file, got %q", storedPath)
	}

	stored, err := os.ReadFile(storedPath)
	if err != nil {
		t.Fatalf("expected stored media file to be readable: %v", err)
	}
	if string(stored) != imageBytes {
		t.Fatalf("expected stored media bytes to match input, got %q", string(stored))
	}
}
// TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia checks that the blob
// of an embedded MCP resource is stored as media with its bytes intact.
func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
	const blob = "blob-bytes"

	mediaStore := media.NewFileMediaStore()
	mock := &MockMCPManager{
		callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
			embedded := &mcp.EmbeddedResource{
				Resource: &mcp.ResourceContents{
					URI:      "file:///tmp/report.png",
					MIMEType: "image/png",
					Blob:     []byte(blob),
				},
			}
			return &mcp.CallToolResult{Content: []mcp.Content{embedded}}, nil
		},
	}

	tool := NewMCPTool(mock, "grafana", &mcp.Tool{Name: "get_dashboard_image"})
	tool.SetMediaStore(mediaStore)

	toolCtx := WithToolContext(context.Background(), "telegram", "chat-42")
	got := tool.Execute(toolCtx, nil)
	if n := len(got.Media); n != 1 {
		t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", n)
	}

	storedPath, _, err := mediaStore.ResolveWithMeta(got.Media[0])
	if err != nil {
		t.Fatalf("expected stored media ref to resolve: %v", err)
	}
	stored, err := os.ReadFile(storedPath)
	if err != nil {
		t.Fatalf("expected stored media file to be readable: %v", err)
	}
	if string(stored) != blob {
		t.Fatalf("expected stored blob bytes to match input, got %q", string(stored))
	}
}
// TestMCPTool_Execute_RespectsUserAudienceForBinaryContent checks that binary
// content annotated for an assistant-only audience is not stored as media and
// that the model receives a note explaining the omission.
func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) {
	mediaStore := media.NewFileMediaStore()
	mock := &MockMCPManager{
		callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
			assistantOnly := &mcp.ImageContent{
				Data:        []byte("assistant-only"),
				MIMEType:    "image/png",
				Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}},
			}
			return &mcp.CallToolResult{Content: []mcp.Content{assistantOnly}}, nil
		},
	}

	tool := NewMCPTool(mock, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
	tool.SetMediaStore(mediaStore)

	toolCtx := WithToolContext(context.Background(), "telegram", "chat-42")
	got := tool.Execute(toolCtx, nil)

	if n := len(got.Media); n != 0 {
		t.Fatalf("expected no media ref for non-user audience, got %d", n)
	}
	if !strings.Contains(got.ForLLM, "non-user audience") {
		t.Fatalf("expected audience note, got %q", got.ForLLM)
	}
}
// TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext checks that a text
// result consisting of a large base64-looking payload is replaced by the
// fixed omission note before it reaches the model.
func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
	// 1600 characters drawn only from the base64 alphabet.
	hugePayload := strings.Repeat("QUJD", 400)

	mock := &MockMCPManager{
		callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
			text := &mcp.TextContent{Text: hugePayload}
			return &mcp.CallToolResult{Content: []mcp.Content{text}}, nil
		},
	}

	tool := NewMCPTool(mock, "test_server", &mcp.Tool{Name: "dump_payload"})
	got := tool.Execute(context.Background(), nil)

	if got.ForLLM != largeBase64OmittedMessage {
		t.Fatalf("expected sanitized large base64 note, got %q", got.ForLLM)
	}
}
+292
View File
@@ -0,0 +1,292 @@
package tools
import (
"encoding/base64"
"fmt"
"mime"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"unicode"
"github.com/sipeed/picoclaw/pkg/media"
)
// Notes substituted into model-facing tool output in place of bulky or
// binary payloads.
const (
	// largeBase64OmittedMessage replaces text that looks like one large
	// base64 blob (see looksLikeLargeBase64Payload).
	largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]"
	// inlineMediaOmittedMessage marks that inline data URLs were stripped
	// from the text.
	inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]"
	// inlineMediaStoredMessage is a format string taking the MIME type of an
	// inline payload that was stored through the media store.
	inlineMediaStoredMessage = "[Tool returned inline media content (%s); omitted from model context and registered as a media attachment.]"
)

var (
	// inlineMarkdownDataURLRe matches a markdown image whose target is a
	// data: URL; capture group 1 is the URL itself.
	inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`)
	// inlineRawDataURLRe matches a bare base64 data: URL. CR and LF are
	// accepted inside the payload.
	inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`)
)
// normalizeToolResult post-processes a tool result so that inline binary
// payloads never reach the model context.
//
// When a media store and a target channel/chat are available, base64 data
// URLs found in ForLLM and ForUser are decoded, stored as media, appended to
// result.Media, and stripped from the text. The seen set makes sure a data
// URL appearing in both fields is stored only once. ForLLM is then sanitized,
// which also covers the no-store case and large bare base64 payloads.
//
// Every note produced during extraction is appended to ForLLM, including the
// failure notes (payload could not be parsed, decoded, or stored). Without
// this, a failed extraction would silently remove the data URL and the model
// would have no indication that content was dropped.
//
// The result is modified in place and returned; a nil result yields nil.
func normalizeToolResult(
	result *ToolResult,
	toolName string,
	store media.MediaStore,
	channel string,
	chatID string,
) *ToolResult {
	if result == nil {
		return nil
	}

	notes := make([]string, 0, 2)
	seen := make(map[string]struct{})

	if store != nil && channel != "" && chatID != "" {
		var refs []string
		var extractedNotes []string

		result.ForLLM, refs, extractedNotes = extractInlineMediaRefs(
			result.ForLLM,
			toolName,
			store,
			channel,
			chatID,
			seen,
		)
		result.Media = append(result.Media, refs...)
		notes = append(notes, extractedNotes...)

		result.ForUser, refs, extractedNotes = extractInlineMediaRefs(
			result.ForUser,
			toolName,
			store,
			channel,
			chatID,
			seen,
		)
		result.Media = append(result.Media, refs...)
		notes = append(notes, extractedNotes...)
	}

	result.ForLLM = sanitizeToolLLMContent(result.ForLLM)

	// Surface extraction notes whether or not any media was stored, so that
	// failures are visible to the model as well as successes.
	if len(notes) > 0 {
		joined := strings.Join(notes, "\n")
		if base := strings.TrimSpace(result.ForLLM); base == "" {
			result.ForLLM = joined
		} else {
			result.ForLLM = base + "\n" + joined
		}
	}

	// Media with no accompanying text still needs something for the model.
	if len(result.Media) > 0 && strings.TrimSpace(result.ForLLM) == "" {
		result.ForLLM = "[Tool returned media content; omitted from model context and registered as a media attachment.]"
	}

	return result
}
// sanitizeToolLLMContent removes payloads from text that should not be sent
// to the model.
//
// Inline data URLs (markdown images or bare) are stripped and an omission
// note is appended; if nothing else remains, the note alone is returned.
// Text that looks like one large base64 blob is replaced entirely by a note.
// Anything else, including blank input, is returned unchanged.
func sanitizeToolLLMContent(text string) string {
	body := strings.TrimSpace(text)
	if body == "" {
		return text
	}

	hasInlineMedia := inlineMarkdownDataURLRe.MatchString(body) || inlineRawDataURLRe.MatchString(body)
	switch {
	case hasInlineMedia:
		// Markdown wrappers go first so no stray "![alt]()" shell is left
		// behind when the raw pattern removes the URL.
		stripped := inlineMarkdownDataURLRe.ReplaceAllString(body, "")
		stripped = strings.TrimSpace(inlineRawDataURLRe.ReplaceAllString(stripped, ""))
		if stripped != "" {
			return stripped + "\n" + inlineMediaOmittedMessage
		}
		return inlineMediaOmittedMessage
	case looksLikeLargeBase64Payload(body):
		return largeBase64OmittedMessage
	default:
		return text
	}
}
// looksLikeLargeBase64Payload is a heuristic for "this text is a big base64
// blob". It reports true when the trimmed text is at least 1024 bytes long,
// at least 97% of its non-space characters belong to the base64 alphabet
// (A-Z, a-z, 0-9, '+', '/', '='), and it holds at most one whitespace
// character per 128 bytes.
func looksLikeLargeBase64Payload(text string) bool {
	const minPayloadBytes = 1024

	payload := strings.TrimSpace(text)
	if len(payload) < minPayloadBytes {
		return false
	}

	var whitespace, visible, alphabetHits int
	for _, ch := range payload {
		switch {
		case unicode.IsSpace(ch):
			whitespace++
		case ch >= 'A' && ch <= 'Z',
			ch >= 'a' && ch <= 'z',
			ch >= '0' && ch <= '9',
			ch == '+', ch == '/', ch == '=':
			visible++
			alphabetHits++
		default:
			visible++
		}
	}

	if visible == 0 {
		return false
	}
	// Ordinary prose contains far more whitespace than wrapped base64 does.
	if whitespace > len(payload)/128 {
		return false
	}
	return float64(alphabetHits)/float64(visible) >= 0.97
}
// extractInlineMediaRefs finds inline data URLs in text, hands each one to
// storeInlineDataURL, and removes every occurrence from the text.
//
// Markdown image data URLs are handled first, then bare data URLs in what
// remains. It returns the trimmed remaining text, the media refs that were
// stored, and the notes produced along the way.
func extractInlineMediaRefs(
	text string,
	toolName string,
	store media.MediaStore,
	channel string,
	chatID string,
	seen map[string]struct{},
) (cleaned string, refs []string, notes []string) {
	cleaned = text

	// keep stores one data URL and collects whatever ref/note comes back.
	keep := func(dataURL string) {
		ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen)
		if ref != "" {
			refs = append(refs, ref)
		}
		if note != "" {
			notes = append(notes, note)
		}
	}

	for _, m := range inlineMarkdownDataURLRe.FindAllStringSubmatch(cleaned, -1) {
		if len(m) < 2 {
			continue
		}
		keep(m[1])
		cleaned = strings.ReplaceAll(cleaned, m[0], "")
	}

	for _, raw := range inlineRawDataURLRe.FindAllString(cleaned, -1) {
		keep(raw)
		cleaned = strings.ReplaceAll(cleaned, raw, "")
	}

	return strings.TrimSpace(cleaned), refs, notes
}
// storeInlineDataURL decodes one base64 data URL, writes it to a temp file,
// and registers that file with the media store.
//
// It returns the media ref on success and a note for the model describing the
// outcome. A data URL already present in seen, or a string that does not
// start with "data:", yields two empty strings. seen is updated so the same
// URL is processed at most once per call chain.
func storeInlineDataURL(
	toolName string,
	store media.MediaStore,
	channel string,
	chatID string,
	dataURL string,
	seen map[string]struct{},
) (ref string, note string) {
	dataURL = strings.TrimSpace(dataURL)
	if _, duplicate := seen[dataURL]; duplicate {
		return "", ""
	}
	seen[dataURL] = struct{}{}

	if !strings.HasPrefix(strings.ToLower(dataURL), "data:") {
		return "", ""
	}

	// Split "data:<mime>;base64" from the payload at the first comma. An
	// index of 5 or less means the comma is missing or nothing follows
	// the "data:" prefix.
	sep := strings.IndexByte(dataURL, ',')
	if sep <= 5 {
		return "", "[Tool returned inline media content that could not be parsed.]"
	}
	header, encoded := dataURL[:sep], dataURL[sep+1:]
	if !strings.Contains(strings.ToLower(header), ";base64") {
		return "", "[Tool returned inline media content that was not base64-encoded.]"
	}

	mimeType := strings.TrimSpace(strings.TrimPrefix(header, "data:"))
	if idx := strings.IndexByte(mimeType, ';'); idx >= 0 {
		mimeType = mimeType[:idx]
	}
	if mimeType == "" {
		mimeType = "application/octet-stream"
	}

	// Drop whitespace that line wrapping may have introduced in the payload.
	encoded = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(encoded)
	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType)
	}

	// Every filesystem failure below reports the same note.
	notStored := fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)

	dir := media.TempDir()
	if err = os.MkdirAll(dir, 0o700); err != nil {
		return "", notStored
	}

	ext := extensionForMIMEType(mimeType)
	tmp, err := os.CreateTemp(dir, "tool-inline-*"+ext)
	if err != nil {
		return "", notStored
	}
	tmpPath := tmp.Name()

	// Always close the file; a failed write or close discards the temp file.
	_, writeErr := tmp.Write(raw)
	closeErr := tmp.Close()
	if writeErr != nil || closeErr != nil {
		_ = os.Remove(tmpPath)
		return "", notStored
	}

	safeTool := sanitizeIdentifierComponent(toolName)
	// The nanosecond timestamp makes the scope distinct per stored artifact.
	scope := fmt.Sprintf("tool:inline:%s:%s:%s:%d", safeTool, channel, chatID, time.Now().UnixNano())
	meta := media.MediaMeta{
		Filename:    safeTool + ext,
		ContentType: mimeType,
		Source:      fmt.Sprintf("tool:inline:%s", safeTool),
	}

	stored, err := store.Store(tmpPath, meta, scope)
	if err != nil {
		_ = os.Remove(tmpPath)
		return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType)
	}

	return stored, fmt.Sprintf(inlineMediaStoredMessage, mimeType)
}
// extensionForMIMEType returns a file extension (with leading dot) suitable
// for a temp file holding content of the given MIME type.
//
// Lookup order:
//  1. An empty type maps to ".bin".
//  2. A fixed table of common image/audio/video types, matched
//     case-insensitively with any parameters ("; charset=...") ignored. The
//     table comes first because mime.ExtensionsByType returns a sorted list
//     whose first entry depends on the host's MIME database, so "image/jpeg"
//     could yield ".jpeg", ".jpe" or ".jfif" instead of the expected ".jpg".
//  3. The first extension reported by mime.ExtensionsByType.
//  4. filepath.Ext of the raw type, which is usually empty.
func extensionForMIMEType(mimeType string) string {
	if mimeType == "" {
		return ".bin"
	}

	// Reduce the input to a bare lower-case media type for the table lookup.
	base := strings.ToLower(strings.TrimSpace(mimeType))
	if parsed, _, err := mime.ParseMediaType(mimeType); err == nil {
		base = parsed
	}

	switch base {
	case "image/jpeg":
		return ".jpg"
	case "image/png":
		return ".png"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/x-wav":
		return ".wav"
	case "audio/mpeg":
		return ".mp3"
	case "audio/ogg":
		return ".ogg"
	case "video/mp4":
		return ".mp4"
	}

	if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 {
		return exts[0]
	}
	return filepath.Ext(mimeType)
}
+34 -5
View File
@@ -9,6 +9,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -19,9 +20,14 @@ type ToolEntry struct {
}
type ToolRegistry struct {
tools map[string]*ToolEntry
mu sync.RWMutex
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
tools map[string]*ToolEntry
mu sync.RWMutex
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
mediaStore media.MediaStore
}
// mediaStoreAware is implemented by tools that accept a MediaStore. The
// registry detects it with a type assertion to decide which tools receive
// the store.
type mediaStoreAware interface {
	SetMediaStore(store media.MediaStore)
}
func NewToolRegistry() *ToolRegistry {
@@ -43,6 +49,9 @@ func (r *ToolRegistry) Register(tool Tool) {
IsCore: true,
TTL: 0, // Core tools do not use TTL
}
if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil {
aware.SetMediaStore(r.mediaStore)
}
r.version.Add(1)
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
}
@@ -61,10 +70,27 @@ func (r *ToolRegistry) RegisterHidden(tool Tool) {
IsCore: false,
TTL: 0,
}
if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil {
aware.SetMediaStore(r.mediaStore)
}
r.version.Add(1)
logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name})
}
// SetMediaStore remembers store on the registry and pushes it into every
// currently registered tool that implements mediaStoreAware. The write lock
// is held for the whole update.
func (r *ToolRegistry) SetMediaStore(store media.MediaStore) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.mediaStore = store
	for name := range r.tools {
		aware, ok := r.tools[name].Tool.(mediaStoreAware)
		if !ok {
			continue
		}
		aware.SetMediaStore(store)
	}
}
// PromoteTools atomically sets the TTL for multiple non-core tools.
// This prevents a concurrent TickTTL from decrementing between promotions.
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
@@ -238,6 +264,8 @@ func (r *ToolRegistry) ExecuteWithContext(
}
}
result = normalizeToolResult(result, name, r.mediaStore, channel, chatID)
duration := time.Since(start)
// Log based on result type
@@ -259,7 +287,7 @@ func (r *ToolRegistry) ExecuteWithContext(
map[string]any{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result.ForLLM),
"result_length": len(result.ContentForLLM()),
})
}
@@ -354,7 +382,8 @@ func (r *ToolRegistry) Clone() *ToolRegistry {
r.mu.RLock()
defer r.mu.RUnlock()
clone := &ToolRegistry{
tools: make(map[string]*ToolEntry, len(r.tools)),
tools: make(map[string]*ToolEntry, len(r.tools)),
mediaStore: r.mediaStore,
}
for name, entry := range r.tools {
clone.tools[name] = &ToolEntry{
+111
View File
@@ -3,10 +3,13 @@ package tools
import (
"context"
"errors"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -46,6 +49,15 @@ func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]
return m.result
}
// mockMediaStoreAwareTool is a registry test double that records the
// MediaStore handed to it through SetMediaStore.
type mockMediaStoreAwareTool struct {
	mockRegistryTool
	store media.MediaStore // last store passed to SetMediaStore
}

// SetMediaStore remembers store so tests can assert that it was propagated.
func (m *mockMediaStoreAwareTool) SetMediaStore(store media.MediaStore) {
	m.store = store
}
// --- helpers ---
func newMockTool(name, desc string) *mockRegistryTool {
@@ -621,3 +633,102 @@ func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) {
t.Errorf("expected 'success', got %q", result2.ForLLM)
}
}
// TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools checks that
// SetMediaStore reaches tools registered before the call and that tools
// registered afterwards inherit the same store.
func TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools(t *testing.T) {
	registry := NewToolRegistry()
	mediaStore := media.NewFileMediaStore()

	// Registered before the store is set.
	before := &mockMediaStoreAwareTool{
		mockRegistryTool: *newMockTool("existing", "existing tool"),
	}
	registry.Register(before)

	registry.SetMediaStore(mediaStore)
	if before.store != mediaStore {
		t.Fatal("expected existing tool to receive media store")
	}

	// Registered after the store is set.
	after := &mockMediaStoreAwareTool{
		mockRegistryTool: *newMockTool("later", "later tool"),
	}
	registry.Register(after)
	if after.store != mediaStore {
		t.Fatal("expected newly registered tool to inherit media store")
	}
}
// TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload checks that
// a tool result made of a large base64-looking payload is replaced by the
// fixed omission note.
func TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload(t *testing.T) {
	registry := NewToolRegistry()

	// 1600 characters drawn only from the base64 alphabet.
	hugePayload := strings.Repeat("QUJD", 400)
	tool := &mockRegistryTool{
		name:   "base64_tool",
		desc:   "returns huge base64",
		params: map[string]any{},
		result: SilentResult(hugePayload),
	}
	registry.Register(tool)

	got := registry.ExecuteWithContext(context.Background(), "base64_tool", nil, "telegram", "chat-1", nil)
	if got.ForLLM != largeBase64OmittedMessage {
		t.Fatalf("expected sanitized payload, got %q", got.ForLLM)
	}
}
// TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL checks that
// an inline markdown data URL in a tool result is stored as a .png media
// file, stripped from ForLLM, and replaced by a delivery note.
func TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL(t *testing.T) {
	registry := NewToolRegistry()
	mediaStore := media.NewFileMediaStore()
	registry.SetMediaStore(mediaStore)

	const inlineImage = "![screenshot](data:image/png;base64,aGVsbG8=)"
	registry.Register(&mockRegistryTool{
		name:   "inline_media_tool",
		desc:   "returns inline data url",
		params: map[string]any{},
		result: SilentResult(inlineImage),
	})

	got := registry.ExecuteWithContext(context.Background(), "inline_media_tool", nil, "telegram", "chat-42", nil)
	if n := len(got.Media); n != 1 {
		t.Fatalf("expected 1 media ref, got %d", n)
	}
	if strings.Contains(got.ForLLM, "data:image/png;base64") {
		t.Fatalf("expected inline data URL to be stripped from ForLLM, got %q", got.ForLLM)
	}
	if !strings.Contains(got.ForLLM, "registered as a media attachment") {
		t.Fatalf("expected delivery note in ForLLM, got %q", got.ForLLM)
	}

	storedPath, err := mediaStore.Resolve(got.Media[0])
	if err != nil {
		t.Fatalf("expected stored media ref to resolve: %v", err)
	}
	if _, statErr := os.Stat(storedPath); statErr != nil {
		t.Fatalf("expected stored media file to exist: %v", statErr)
	}
	if filepath.Ext(storedPath) != ".png" {
		t.Fatalf("expected stored inline media to use png extension, got %q", storedPath)
	}
}
// TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore checks
// that on a registry where SetMediaStore was never called, an inline data URL
// is still removed from ForLLM and inlineMediaOmittedMessage is present in its
// place.
func TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore(t *testing.T) {
	registry := NewToolRegistry()

	textWithImage := "before ![img](data:image/png;base64,aGVsbG8=) after"
	tool := &mockRegistryTool{
		name:   "inline_media_no_store",
		desc:   "returns inline data url without store",
		params: map[string]any{},
		result: SilentResult(textWithImage),
	}
	registry.Register(tool)

	res := registry.ExecuteWithContext(context.Background(), "inline_media_no_store", nil, "telegram", "chat-42", nil)
	llmText := res.ForLLM

	if strings.Contains(llmText, "data:image/png;base64") {
		t.Fatalf("expected inline data URL to be removed from ForLLM, got %q", llmText)
	}
	if !strings.Contains(llmText, inlineMediaOmittedMessage) {
		t.Fatalf("expected inline media omission note, got %q", llmText)
	}
}
+54
View File
@@ -2,10 +2,16 @@ package tools
import (
"encoding/json"
"strings"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
// handledToolLLMNote is the instruction ContentForLLM adds to the content of
// a tool result whose ResponseHandled flag is set. It tells the model that
// delivery already happened so it does not trigger a second delivery.
handledToolLLMNote = "The requested output has already been delivered to the user in the current chat. Do not call send_file or any other delivery tool again. If you reply, provide only a brief confirmation."
// artifactPathsLLMNote is the guidance line ContentForLLM places after the
// "Local artifact paths: ..." listing when a tool result carries ArtifactTags.
artifactPathsLLMNote = "Use `send_file` with one of these paths to send it to the user, or use file/exec tools to save it inside the workspace if requested."
)
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
@@ -43,6 +49,48 @@ type ToolResult struct {
// Only populated by SubTurn executions; used by evaluator_optimizer
// to carry stateful worker context across evaluation iterations.
Messages []providers.Message `json:"-"`
// ArtifactTags exposes local artifact paths back to the LLM in a structured
// form, e.g. "[file:/tmp/example.png]". This is used when a tool produced a
// reusable local artifact but did not deliver it to the user yet.
ArtifactTags []string `json:"artifact_tags,omitempty"`
// ResponseHandled indicates that this tool execution already satisfied the
// user's request at the channel/output level, so the agent loop can stop
// without a follow-up assistant response.
ResponseHandled bool `json:"response_handled,omitempty"`
}
// ContentForLLM returns the normalized textual content to append to the
// conversation after a tool call.
//
// The text is assembled in this order:
//  1. ForLLM, or Err.Error() when ForLLM is empty and Err is non-nil.
//  2. handledToolLLMNote, when ResponseHandled is true and the note is not
//     already contained in the content.
//  3. A "Local artifact paths: ..." line listing ArtifactTags (space
//     separated) followed by artifactPathsLLMNote, when ArtifactTags is
//     non-empty and that exact note is not already contained in the content.
//
// Sections are separated by a single newline. A nil receiver, or a result with
// nothing to report, yields the empty string.
func (tr *ToolResult) ContentForLLM() string {
	if tr == nil {
		return ""
	}

	content := tr.ForLLM
	if content == "" && tr.Err != nil {
		content = tr.Err.Error()
	}

	if tr.ResponseHandled {
		switch {
		case content == "":
			// Assign instead of returning early: any ArtifactTags must still
			// be appended below, exactly as they are when content is non-empty.
			content = handledToolLLMNote
		case !strings.Contains(content, handledToolLLMNote):
			content += "\n" + handledToolLLMNote
		}
	}

	if len(tr.ArtifactTags) > 0 {
		artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote
		switch {
		case content == "":
			content = artifactNote
		case !strings.Contains(content, artifactNote):
			content += "\n" + artifactNote
		}
	}

	return content
}
// NewToolResult creates a basic ToolResult with content for the LLM.
@@ -167,3 +215,9 @@ func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}
// WithResponseHandled marks the tool result as already delivered to the user.
// It sets ResponseHandled to true and returns the same receiver, so the call
// can be chained onto a constructor (for example
// MediaResult(...).WithResponseHandled()). The receiver must be non-nil.
func (tr *ToolResult) WithResponseHandled() *ToolResult {
tr.ResponseHandled = true
return tr
}
+39
View File
@@ -3,6 +3,7 @@ package tools
import (
"encoding/json"
"errors"
"strings"
"testing"
)
@@ -227,3 +228,41 @@ func TestToolResultJSONStructure(t *testing.T) {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}
// TestToolResultContentForLLM_AppendsHandledDeliveryNote checks that a media
// result marked as handled keeps its original text in ContentForLLM and also
// carries handledToolLLMNote.
func TestToolResultContentForLLM_AppendsHandledDeliveryNote(t *testing.T) {
	delivered := MediaResult("Screenshot attached.", []string{"media://example"})
	delivered = delivered.WithResponseHandled()

	rendered := delivered.ContentForLLM()

	keepsOriginal := strings.Contains(rendered, "Screenshot attached.")
	if !keepsOriginal {
		t.Fatalf("expected original content in ContentForLLM, got %q", rendered)
	}

	hasNote := strings.Contains(rendered, handledToolLLMNote)
	if !hasNote {
		t.Fatalf("expected handled delivery note in ContentForLLM, got %q", rendered)
	}
}
// TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty checks that a
// handled result with no other content renders as exactly handledToolLLMNote.
func TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty(t *testing.T) {
	handled := (&ToolResult{}).WithResponseHandled()

	got := handled.ContentForLLM()
	if got == handledToolLLMNote {
		return
	}
	t.Fatalf("ContentForLLM() = %q, want %q", got, handledToolLLMNote)
}
// TestToolResultContentForLLM_AppendsArtifactPaths checks that a result with
// ArtifactTags renders its original text, the "Local artifact paths" listing,
// and artifactPathsLLMNote.
func TestToolResultContentForLLM_AppendsArtifactPaths(t *testing.T) {
	artifact := &ToolResult{
		ForLLM:       "Artifact created.",
		ArtifactTags: []string{"[file:/tmp/example.png]"},
	}
	rendered := artifact.ContentForLLM()

	// Each fragment must appear in the rendered text; checks run in order and
	// the first missing fragment aborts the test with its own message.
	checks := []struct {
		fragment string
		failure  string
	}{
		{"Artifact created.", "expected original content in ContentForLLM, got %q"},
		{"Local artifact paths: [file:/tmp/example.png]", "expected artifact path note in ContentForLLM, got %q"},
		{artifactPathsLLMNote, "expected artifact guidance note in ContentForLLM, got %q"},
	}
	for _, check := range checks {
		if !strings.Contains(rendered, check.fragment) {
			t.Fatalf(check.failure, rendered)
		}
	}
}
+1 -1
View File
@@ -142,7 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
}
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref})
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled()
}
// detectMediaType determines the MIME type of a file.
+3
View File
@@ -104,6 +104,9 @@ func TestSendFileTool_Success(t *testing.T) {
if result.Media[0][:8] != "media://" {
t.Errorf("expected media:// ref, got %q", result.Media[0])
}
if !result.ResponseHandled {
t.Fatal("expected send_file success to mark response handled")
}
_, meta, err := store.ResolveWithMeta(result.Media[0])
if err != nil {
+1 -4
View File
@@ -159,10 +159,7 @@ func RunToolLoop(
// Append results in original order
for _, r := range results {
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
contentForLLM := r.result.ContentForLLM()
messages = append(messages, providers.Message{
Role: "tool",