mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1889 from afjcjsbx/fix/binary-tool-output-handling
fix(tool): route binary outputs through the media pipeline
This commit is contained in:
+221
-114
@@ -96,14 +96,15 @@ type continuationTarget struct {
|
||||
}
|
||||
|
||||
// String constants shared by the agent loop.
//
// handledToolResponseSummary is the assistant message written to session
// history when every tool call of a turn already delivered its output, so the
// turn ends without a follow-up LLM response.
const (
	defaultResponse            = "The model returned an empty response. This may indicate a provider error or token limit."
	toolLimitResponse          = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
	handledToolResponseSummary = "Requested output delivered via tool attachment."
	sessionKeyAgentPrefix      = "agent:"
	metadataKeyAccountID       = "account_id"
	metadataKeyGuildID         = "guild_id"
	metadataKeyTeamID          = "team_id"
	metadataKeyParentPeerKind  = "parent_peer_kind"
	metadataKeyParentPeerID    = "parent_peer_id"
)
|
||||
|
||||
func NewAgentLoop(
|
||||
@@ -1030,13 +1031,13 @@ func (al *AgentLoop) GetConfig() *config.Config {
|
||||
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
al.mediaStore = s
|
||||
|
||||
// Propagate store to send_file tools in all agents.
|
||||
// Propagate store to all registered tools that can emit media.
|
||||
registry := al.GetRegistry()
|
||||
registry.ForEachTool("send_file", func(t tools.Tool) {
|
||||
if sf, ok := t.(*tools.SendFileTool); ok {
|
||||
sf.SetMediaStore(s)
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
agent.Tools.SetMediaStore(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
|
||||
@@ -2165,6 +2166,7 @@ turnLoop:
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
allResponsesHandled := len(normalizedToolCalls) > 0
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
@@ -2221,6 +2223,7 @@ turnLoop:
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionDenyTool:
|
||||
allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
@@ -2260,6 +2263,7 @@ turnLoop:
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
if !approval.Approved {
|
||||
allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
@@ -2333,10 +2337,7 @@ turnLoop:
|
||||
}
|
||||
|
||||
// Determine content for the agent loop (ForLLM or error).
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
content := result.ContentForLLM()
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
@@ -2420,6 +2421,50 @@ turnLoop:
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver handled tool media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
// Queuing media is only best-effort; it has not been delivered yet.
|
||||
toolResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
|
||||
}
|
||||
|
||||
if !toolResult.ResponseHandled {
|
||||
allResponsesHandled = false
|
||||
}
|
||||
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
@@ -2434,30 +2479,7 @@ turnLoop:
|
||||
})
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
contentForLLM := toolResult.ContentForLLM()
|
||||
|
||||
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
@@ -2552,6 +2574,70 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
|
||||
if allResponsesHandled {
|
||||
if len(pendingMessages) > 0 {
|
||||
logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(pendingMessages),
|
||||
"session_key": ts.sessionKey,
|
||||
})
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(steerMsgs),
|
||||
"session_key": ts.sessionKey,
|
||||
})
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
}
|
||||
|
||||
summaryMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: handledToolResponseSummary,
|
||||
}
|
||||
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
|
||||
ts.recordPersistedMessage(summaryMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("runTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "session_save",
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return turnResult{}, err
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
ts.setFinalContent("")
|
||||
logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"tool_count": len(normalizedToolCalls),
|
||||
})
|
||||
return turnResult{
|
||||
finalContent: "",
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ts.agent.Tools.TickTTL()
|
||||
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
|
||||
"agent_id": ts.agent.ID, "iteration": iteration,
|
||||
@@ -3159,6 +3245,97 @@ func (al *AgentLoop) handleCommand(
|
||||
}
|
||||
}
|
||||
|
||||
func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
|
||||
if agent == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
combined := make([]string, 0, len(agent.SkillsFilter)+len(opts.ForcedSkills))
|
||||
combined = append(combined, agent.SkillsFilter...)
|
||||
combined = append(combined, opts.ForcedSkills...)
|
||||
if len(combined) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var resolved []string
|
||||
seen := make(map[string]struct{}, len(combined))
|
||||
for _, name := range combined {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if agent.ContextBuilder != nil {
|
||||
if canonical, ok := agent.ContextBuilder.ResolveSkillName(name); ok {
|
||||
name = canonical
|
||||
}
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
resolved = append(resolved, name)
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
func (al *AgentLoop) applyExplicitSkillCommand(
|
||||
raw string,
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
) (matched bool, handled bool, reply string) {
|
||||
cmdName, ok := commands.CommandName(raw)
|
||||
if !ok || cmdName != "use" {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
if agent == nil || agent.ContextBuilder == nil {
|
||||
return true, true, commandsUnavailableSkillMessage()
|
||||
}
|
||||
|
||||
parts := strings.Fields(strings.TrimSpace(raw))
|
||||
if len(parts) < 2 {
|
||||
return true, true, buildUseCommandHelp(agent)
|
||||
}
|
||||
|
||||
arg := strings.TrimSpace(parts[1])
|
||||
if strings.EqualFold(arg, "clear") || strings.EqualFold(arg, "off") {
|
||||
if opts != nil {
|
||||
al.clearPendingSkills(opts.SessionKey)
|
||||
}
|
||||
return true, true, "Cleared pending skill override."
|
||||
}
|
||||
|
||||
skillName, ok := agent.ContextBuilder.ResolveSkillName(arg)
|
||||
if !ok {
|
||||
return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", arg)
|
||||
}
|
||||
|
||||
if len(parts) < 3 {
|
||||
if opts == nil || strings.TrimSpace(opts.SessionKey) == "" {
|
||||
return true, true, commandsUnavailableSkillMessage()
|
||||
}
|
||||
al.setPendingSkills(opts.SessionKey, []string{skillName})
|
||||
return true, true, fmt.Sprintf(
|
||||
"Skill %q is armed for your next message. Send your next prompt normally, or use /use clear to cancel.",
|
||||
skillName,
|
||||
)
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(strings.Join(parts[2:], " "))
|
||||
if message == "" {
|
||||
return true, true, buildUseCommandHelp(agent)
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
opts.ForcedSkills = append(opts.ForcedSkills, skillName)
|
||||
opts.UserMessage = message
|
||||
}
|
||||
|
||||
return true, false, ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
|
||||
registry := al.GetRegistry()
|
||||
cfg := al.GetConfig()
|
||||
@@ -3199,6 +3376,9 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
return al.reloadFunc()
|
||||
}
|
||||
if agent != nil {
|
||||
if agent.ContextBuilder != nil {
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
@@ -3251,79 +3431,6 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
return rt
|
||||
}
|
||||
|
||||
func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
|
||||
var out []string
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
appendNames := func(names []string) {
|
||||
for _, name := range names {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[name]; exists {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
|
||||
if agent != nil {
|
||||
appendNames(agent.SkillsFilter)
|
||||
}
|
||||
appendNames(opts.ForcedSkills)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (al *AgentLoop) applyExplicitSkillCommand(
|
||||
raw string,
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
) (matched bool, handled bool, reply string) {
|
||||
commandName, ok := commands.CommandName(raw)
|
||||
if !ok || commandName != "use" {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
if agent == nil || agent.ContextBuilder == nil {
|
||||
return true, true, commandsUnavailableSkillMessage()
|
||||
}
|
||||
|
||||
fields := strings.Fields(strings.TrimSpace(raw))
|
||||
if len(fields) < 2 {
|
||||
return true, true, buildUseCommandHelp(agent)
|
||||
}
|
||||
|
||||
if strings.EqualFold(fields[1], "clear") || strings.EqualFold(fields[1], "off") {
|
||||
al.clearPendingSkills(opts.SessionKey)
|
||||
return true, true, "Cleared pending skill override."
|
||||
}
|
||||
|
||||
canonicalSkill, ok := agent.ContextBuilder.ResolveSkillName(fields[1])
|
||||
if !ok {
|
||||
return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", fields[1])
|
||||
}
|
||||
|
||||
if len(fields) == 2 {
|
||||
al.setPendingSkills(opts.SessionKey, []string{canonicalSkill})
|
||||
return true, true, fmt.Sprintf(
|
||||
"Skill %q is armed for your next message.\nSend your next request normally, or use /use clear to cancel.",
|
||||
canonicalSkill,
|
||||
)
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(strings.Join(fields[2:], " "))
|
||||
if message == "" {
|
||||
return true, true, buildUseCommandHelp(agent)
|
||||
}
|
||||
|
||||
opts.UserMessage = message
|
||||
opts.ForcedSkills = append(opts.ForcedSkills, canonicalSkill)
|
||||
return true, false, ""
|
||||
}
|
||||
|
||||
// commandsUnavailableSkillMessage returns the fixed reply used when a skill
// command cannot be served in the current context.
func commandsUnavailableSkillMessage() string {
	const unavailable = "Skill selection is unavailable in the current context."
	return unavailable
}
|
||||
|
||||
@@ -87,6 +87,24 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
|
||||
return result
|
||||
}
|
||||
|
||||
func buildArtifactTags(store media.MediaStore, refs []string) []string {
|
||||
if store == nil || len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags := make([]string, 0, len(refs))
|
||||
for _, ref := range refs {
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
mime := detectMIME(localPath, meta)
|
||||
tags = append(tags, buildPathTag(mime, localPath))
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
// detectMIME determines the MIME type from metadata or magic-bytes detection.
|
||||
// Returns empty string if detection fails.
|
||||
func detectMIME(localPath string, meta media.MediaMeta) string {
|
||||
|
||||
@@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
// IsAllowedSender accepts every sender; the fake performs no filtering.
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool {
	return true
}
|
||||
// ReasoningChannelID returns the id the fake channel was constructed with.
func (f *fakeChannel) ReasoningChannelID() string {
	return f.id
}
|
||||
|
||||
type fakeMediaChannel struct {
|
||||
fakeChannel
|
||||
sentMedia []bus.OutboundMediaMessage
|
||||
}
|
||||
|
||||
func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
f.sentMedia = append(f.sentMedia, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newStartedTestChannelManager(
|
||||
t *testing.T,
|
||||
msgBus *bus.MessageBus,
|
||||
store media.MediaStore,
|
||||
name string,
|
||||
ch channels.Channel,
|
||||
) *channels.Manager {
|
||||
t.Helper()
|
||||
|
||||
cm, err := channels.NewManager(&config.Config{}, msgBus, store)
|
||||
if err != nil {
|
||||
t.Fatalf("NewManager() error = %v", err)
|
||||
}
|
||||
cm.RegisterChannel(name, ch)
|
||||
if err := cm.StartAll(context.Background()); err != nil {
|
||||
t.Fatalf("StartAll() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := cm.StopAll(context.Background()); err != nil {
|
||||
t.Fatalf("StopAll() error = %v", err)
|
||||
}
|
||||
})
|
||||
return cm
|
||||
}
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
}
|
||||
@@ -289,6 +324,86 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyExplicitSkillCommand_ArmsSkillForNextMessage(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(skill) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"),
|
||||
[]byte("# Finance News\n\nUse web tools for current finance updates.\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(SKILL.md) error = %v", err)
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
opts := &processOptions{SessionKey: "agent:main:test"}
|
||||
matched, handled, reply := al.applyExplicitSkillCommand("/use finance-news", agent, opts)
|
||||
if !matched {
|
||||
t.Fatal("expected /use command to match")
|
||||
}
|
||||
if !handled {
|
||||
t.Fatal("expected /use without inline message to be handled immediately")
|
||||
}
|
||||
if !strings.Contains(reply, `Skill "finance-news" is armed for your next message`) {
|
||||
t.Fatalf("unexpected reply: %q", reply)
|
||||
}
|
||||
|
||||
pending := al.takePendingSkills(opts.SessionKey)
|
||||
if len(pending) != 1 || pending[0] != "finance-news" {
|
||||
t.Fatalf("pending skills = %#v, want [finance-news]", pending)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyExplicitSkillCommand_InlineMessageMutatesOptions(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(skill) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"),
|
||||
[]byte("# Finance News\n\nUse web tools for current finance updates.\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(SKILL.md) error = %v", err)
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
opts := &processOptions{
|
||||
SessionKey: "agent:main:test",
|
||||
UserMessage: "/use finance-news dammi le ultime news",
|
||||
}
|
||||
matched, handled, reply := al.applyExplicitSkillCommand(opts.UserMessage, agent, opts)
|
||||
if !matched {
|
||||
t.Fatal("expected /use command to match")
|
||||
}
|
||||
if handled {
|
||||
t.Fatal("expected /use with inline message to fall through into normal agent execution")
|
||||
}
|
||||
if reply != "" {
|
||||
t.Fatalf("unexpected reply: %q", reply)
|
||||
}
|
||||
if opts.UserMessage != "dammi le ultime news" {
|
||||
t.Fatalf("opts.UserMessage = %q, want %q", opts.UserMessage, "dammi le ultime news")
|
||||
}
|
||||
if len(opts.ForcedSkills) != 1 || opts.ForcedSkills[0] != "finance-news" {
|
||||
t.Fatalf("opts.ForcedSkills = %#v, want [finance-news]", opts.ForcedSkills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
@@ -455,6 +570,217 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &handledMediaProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
||||
}
|
||||
|
||||
al.RegisterTool(&handledMediaTool{
|
||||
store: store,
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "" {
|
||||
t.Fatalf("expected no final response when media tool already handled delivery, got %q", response)
|
||||
}
|
||||
if provider.calls != 1 {
|
||||
t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls)
|
||||
}
|
||||
if len(provider.toolCounts) != 1 {
|
||||
t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts))
|
||||
}
|
||||
if provider.toolCounts[0] == 0 {
|
||||
t.Fatal("expected tools to be available on the first LLM call")
|
||||
}
|
||||
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
case extra := <-msgBus.OutboundMediaChan():
|
||||
t.Fatalf("expected handled media to bypass async queue, got %+v", extra)
|
||||
default:
|
||||
}
|
||||
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
sessionKey := resolveScopeKey(route, "")
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
t.Fatal("expected session history to be saved")
|
||||
}
|
||||
last := history[len(history)-1]
|
||||
if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." {
|
||||
t.Fatalf("expected handled assistant summary in history, got %+v", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &handledMediaWithSteeringProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen-steering.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
||||
}
|
||||
|
||||
al.RegisterTool(&handledMediaWithSteeringTool{
|
||||
store: store,
|
||||
path: imagePath,
|
||||
loop: al,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Handled the queued steering message." {
|
||||
t.Fatalf("response = %q, want queued steering response", response)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
|
||||
}
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = tmpDir
|
||||
cfg.Agents.Defaults.ModelName = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &artifactThenSendProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
imagePath := filepath.Join(mediaDir, "artifact-screen.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
||||
}
|
||||
|
||||
al.RegisterTool(&mediaArtifactTool{
|
||||
store: store,
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "" {
|
||||
t.Fatalf("expected no final response after send_file handled delivery, got %q", response)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
|
||||
}
|
||||
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
case extra := <-msgBus.OutboundMediaChan():
|
||||
t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
|
||||
func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
@@ -600,6 +926,98 @@ func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
type handledMediaProvider struct {
|
||||
calls int
|
||||
toolCounts []int
|
||||
}
|
||||
|
||||
func (m *handledMediaProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
m.toolCounts = append(m.toolCounts, len(tools))
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Taking the screenshot now.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_handled_media",
|
||||
Type: "function",
|
||||
Name: "handled_media_tool",
|
||||
Arguments: map[string]any{},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
return &providers.LLMResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *handledMediaProvider) GetDefaultModel() string {
|
||||
return "handled-media-model"
|
||||
}
|
||||
|
||||
type artifactThenSendProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *artifactThenSendProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Taking the screenshot now.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_artifact_media",
|
||||
Type: "function",
|
||||
Name: "media_artifact_tool",
|
||||
Arguments: map[string]any{},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var artifactPath string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role != "tool" {
|
||||
continue
|
||||
}
|
||||
start := strings.Index(messages[i].Content, "[file:")
|
||||
if start < 0 {
|
||||
continue
|
||||
}
|
||||
rest := messages[i].Content[start+len("[file:"):]
|
||||
end := strings.Index(rest, "]")
|
||||
if end < 0 {
|
||||
continue
|
||||
}
|
||||
artifactPath = rest[:end]
|
||||
break
|
||||
}
|
||||
if artifactPath == "" {
|
||||
return nil, fmt.Errorf("provider did not receive artifact path in tool result")
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_send_file",
|
||||
Type: "function",
|
||||
Name: "send_file",
|
||||
Arguments: map[string]any{"path": artifactPath},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *artifactThenSendProvider) GetDefaultModel() string {
|
||||
return "artifact-then-send-model"
|
||||
}
|
||||
|
||||
// toolLimitOnlyProvider is a stateless test provider; its behavior lives
// entirely in its methods.
type toolLimitOnlyProvider struct{}
|
||||
|
||||
func (m *toolLimitOnlyProvider) Chat(
|
||||
@@ -646,6 +1064,135 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
type handledMediaTool struct {
|
||||
store media.MediaStore
|
||||
path string
|
||||
}
|
||||
|
||||
func (m *handledMediaTool) Name() string { return "handled_media_tool" }
|
||||
func (m *handledMediaTool) Description() string {
|
||||
return "Returns a media attachment and fully handles the user response"
|
||||
}
|
||||
|
||||
func (m *handledMediaTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
ref, err := m.store.Store(m.path, media.MediaMeta{
|
||||
Filename: filepath.Base(m.path),
|
||||
ContentType: "image/png",
|
||||
Source: "test:handled_media_tool",
|
||||
}, "test:handled_media")
|
||||
if err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
type handledMediaWithSteeringProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Taking the screenshot now.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_handled_media_steering",
|
||||
Type: "function",
|
||||
Name: "handled_media_with_steering_tool",
|
||||
Arguments: map[string]any{},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "user" && msg.Content == "what about this instead?" {
|
||||
return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("provider did not receive queued steering message")
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringProvider) GetDefaultModel() string {
|
||||
return "handled-media-with-steering-model"
|
||||
}
|
||||
|
||||
type handledMediaWithSteeringTool struct {
|
||||
store media.MediaStore
|
||||
path string
|
||||
loop *AgentLoop
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" }
|
||||
func (m *handledMediaWithSteeringTool) Description() string {
|
||||
return "Returns handled media and enqueues a steering message during execution"
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
|
||||
ref, err := m.store.Store(m.path, media.MediaMeta{
|
||||
Filename: filepath.Base(m.path),
|
||||
ContentType: "image/png",
|
||||
Source: "test:handled_media_with_steering_tool",
|
||||
}, "test:handled_media_with_steering")
|
||||
if err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
type mediaArtifactTool struct {
|
||||
store media.MediaStore
|
||||
path string
|
||||
}
|
||||
|
||||
func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" }
|
||||
func (m *mediaArtifactTool) Description() string {
|
||||
return "Returns a media artifact that the agent can forward or save later"
|
||||
}
|
||||
|
||||
func (m *mediaArtifactTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
ref, err := m.store.Store(m.path, media.MediaMeta{
|
||||
Filename: filepath.Base(m.path),
|
||||
ContentType: "image/png",
|
||||
Source: "test:media_artifact_tool",
|
||||
}, "test:media_artifact")
|
||||
if err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
return tools.MediaResult("Artifact created.", []string{ref})
|
||||
}
|
||||
|
||||
// toolLimitTestTool is a field-less tool stub used by the tests; its behavior
// lives entirely in its methods.
type toolLimitTestTool struct{}
|
||||
|
||||
func (m *toolLimitTestTool) Name() string {
|
||||
|
||||
+75
-9
@@ -206,6 +206,40 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
return false
|
||||
}
|
||||
|
||||
// preSendMedia handles typing stop, reaction undo, and placeholder cleanup
|
||||
// before sending media attachments. Unlike preSend for text messages, media
|
||||
// delivery never edits the placeholder because there is no text payload to
|
||||
// replace it with; it only attempts to delete the placeholder when possible.
|
||||
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
|
||||
key := name + ":" + msg.ChatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(typingEntry); ok {
|
||||
entry.stop() // idempotent, safe
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Undo reaction
|
||||
if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(reactionEntry); ok {
|
||||
entry.undo() // idempotent, safe
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Clear any finalized stream marker for this chat before media delivery.
|
||||
m.streamActive.LoadAndDelete(key)
|
||||
|
||||
// 4. Delete placeholder if present.
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
|
||||
m := &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
@@ -774,7 +808,7 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
m.sendMediaWithRetry(ctx, name, w, msg)
|
||||
_ = m.sendMediaWithRetry(ctx, name, w, msg)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -782,26 +816,37 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor
|
||||
}
|
||||
|
||||
// sendMediaWithRetry sends a media message through the channel with rate limiting and
|
||||
// retry logic. If the channel does not implement MediaSender, it silently skips.
|
||||
func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) {
|
||||
// retry logic. It returns nil on success, or the last error after retries,
|
||||
// including when the channel does not support MediaSender.
|
||||
func (m *Manager) sendMediaWithRetry(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
w *channelWorker,
|
||||
msg bus.OutboundMediaMessage,
|
||||
) error {
|
||||
ms, ok := w.ch.(MediaSender)
|
||||
if !ok {
|
||||
logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{
|
||||
err := fmt.Errorf("channel %q does not support media sending", name)
|
||||
logger.WarnCF("channels", "Channel does not support MediaSender", map[string]any{
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
// Pre-send: stop typing and clean up any placeholder before sending media.
|
||||
m.preSendMedia(ctx, name, msg, w.ch)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
lastErr = ms.SendMedia(ctx, msg)
|
||||
if lastErr == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Permanent failures — don't retry
|
||||
@@ -820,7 +865,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
|
||||
case <-time.After(rateLimitDelay):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,7 +874,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -840,6 +885,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// runTTLJanitor periodically scans the typingStops and placeholders maps
|
||||
@@ -1032,6 +1078,26 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMedia sends outbound media synchronously through the channel worker's
|
||||
// rate limiter and retry logic. It blocks until the media is delivered (or all
|
||||
// retries are exhausted), which preserves ordering when later agent behavior
|
||||
// depends on actual media delivery.
|
||||
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
}
|
||||
|
||||
return m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
|
||||
}
|
||||
|
||||
func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error {
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[channelName]
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -43,6 +44,40 @@ func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, conten
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockMediaChannel struct {
|
||||
mockChannel
|
||||
sendMediaFn func(ctx context.Context, msg bus.OutboundMediaMessage) error
|
||||
sentMediaMessages []bus.OutboundMediaMessage
|
||||
}
|
||||
|
||||
func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
m.sentMediaMessages = append(m.sentMediaMessages, msg)
|
||||
if m.sendMediaFn != nil {
|
||||
return m.sendMediaFn(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeletingMediaChannel struct {
|
||||
mockMediaChannel
|
||||
deleteCalls int
|
||||
lastDeleted struct {
|
||||
chatID string
|
||||
messageID string
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDeletingMediaChannel) DeleteMessage(
|
||||
_ context.Context,
|
||||
chatID string,
|
||||
messageID string,
|
||||
) error {
|
||||
m.deleteCalls++
|
||||
m.lastDeleted.chatID = chatID
|
||||
m.lastDeleted.messageID = messageID
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTestManager creates a minimal Manager suitable for unit tests.
|
||||
func newTestManager() *Manager {
|
||||
return &Manager{
|
||||
@@ -208,6 +243,125 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
|
||||
callCount++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 SendMedia call, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_PropagatesFailure(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
|
||||
return fmt.Errorf("bad upload: %w", ErrSendFailed)
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error")
|
||||
}
|
||||
if !errors.Is(err, ErrSendFailed) {
|
||||
t.Fatalf("expected ErrSendFailed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error for unsupported channel")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not support media sending") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockDeletingMediaChannel{
|
||||
mockMediaChannel: mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
m.RecordPlaceholder("test", "chat1", "placeholder-1")
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
if ch.deleteCalls != 1 {
|
||||
t.Fatalf("expected placeholder delete to be called once, got %d", ch.deleteCalls)
|
||||
}
|
||||
if ch.lastDeleted.chatID != "chat1" || ch.lastDeleted.messageID != "placeholder-1" {
|
||||
t.Fatalf("unexpected placeholder deletion target: %+v", ch.lastDeleted)
|
||||
}
|
||||
if len(ch.sentMediaMessages) != 1 {
|
||||
t.Fatalf("expected media to be sent once, got %d", len(ch.sentMediaMessages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_UnknownError(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
|
||||
+269
-11
@@ -5,9 +5,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
@@ -25,6 +29,7 @@ type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
@@ -36,6 +41,10 @@ func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
@@ -218,13 +227,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
// Extract text content from result
|
||||
output := extractContentText(result.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
IsError: false,
|
||||
}
|
||||
return t.normalizeResultContent(ctx, result.Content)
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
@@ -233,14 +236,269 @@ func extractContentText(content []mcp.Content) string {
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, v.Text)
|
||||
parts = append(parts, sanitizeToolLLMContent(v.Text))
|
||||
case *mcp.ImageContent:
|
||||
// For images, just indicate that an image was returned
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType)))
|
||||
case *mcp.AudioContent:
|
||||
parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType)))
|
||||
case *mcp.ResourceLink:
|
||||
parts = append(parts, summarizeResourceLink(v))
|
||||
case *mcp.EmbeddedResource:
|
||||
parts = append(parts, summarizeEmbeddedResource(v))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
return sanitizeToolLLMContent(strings.Join(parts, "\n"))
|
||||
}
|
||||
|
||||
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
|
||||
llmParts := make([]string, 0, len(content))
|
||||
mediaRefs := make([]string, 0, len(content))
|
||||
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
text := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
|
||||
if text != "" {
|
||||
llmParts = append(llmParts, text)
|
||||
}
|
||||
case *mcp.ImageContent:
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"image",
|
||||
normalizedMIMEType(v.MIMEType),
|
||||
v.Data,
|
||||
v.Annotations,
|
||||
)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
case *mcp.AudioContent:
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"audio",
|
||||
normalizedMIMEType(v.MIMEType),
|
||||
v.Data,
|
||||
v.Annotations,
|
||||
)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
case *mcp.ResourceLink:
|
||||
llmParts = append(llmParts, summarizeResourceLink(v))
|
||||
case *mcp.EmbeddedResource:
|
||||
ref, note := t.storeEmbeddedResource(ctx, v)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
default:
|
||||
llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v))
|
||||
}
|
||||
}
|
||||
|
||||
result := &ToolResult{
|
||||
ForLLM: strings.Join(compactStrings(llmParts), "\n"),
|
||||
Media: mediaRefs,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "", "[MCP returned an embedded resource without data.]"
|
||||
}
|
||||
|
||||
resource := content.Resource
|
||||
if len(resource.Blob) > 0 {
|
||||
return t.storeBinaryContent(
|
||||
ctx,
|
||||
"resource",
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
resource.Blob,
|
||||
content.Annotations,
|
||||
)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(resource.Text) != "" {
|
||||
return "", sanitizeToolLLMContent(resource.Text)
|
||||
}
|
||||
|
||||
return "", summarizeEmbeddedResource(content)
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeBinaryContent(
|
||||
ctx context.Context,
|
||||
kind string,
|
||||
mimeType string,
|
||||
data []byte,
|
||||
annotations *mcp.Annotations,
|
||||
) (string, string) {
|
||||
if len(data) == 0 {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType)
|
||||
}
|
||||
if !annotationsAllowUser(annotations) {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) for non-user audience; omitted from model context.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
if t.mediaStore == nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
if channel == "" || chatID == "" {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context because no target chat was available.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
dir := media.TempDir()
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf(
|
||||
"tool:mcp:%s:%s:%s:%d",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
channel,
|
||||
chatID,
|
||||
time.Now().UnixNano(),
|
||||
)
|
||||
filename := fmt.Sprintf(
|
||||
"%s_%s%s",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
ext,
|
||||
)
|
||||
|
||||
ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: mimeType,
|
||||
Source: fmt.Sprintf(
|
||||
"tool:mcp:%s:%s",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
),
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be registered as media.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
return ref, fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
func summarizeResourceLink(content *mcp.ResourceLink) string {
|
||||
if content == nil {
|
||||
return "[MCP returned an empty resource link.]"
|
||||
}
|
||||
|
||||
parts := []string{"[MCP returned resource link"}
|
||||
if content.Name != "" {
|
||||
parts = append(parts, fmt.Sprintf("name=%q", content.Name))
|
||||
}
|
||||
if content.URI != "" {
|
||||
parts = append(parts, fmt.Sprintf("uri=%q", content.URI))
|
||||
}
|
||||
if content.MIMEType != "" {
|
||||
parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType))
|
||||
}
|
||||
if content.Description != "" {
|
||||
desc := strings.TrimSpace(content.Description)
|
||||
if len(desc) > 200 {
|
||||
desc = desc[:200] + "..."
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("description=%q", desc))
|
||||
}
|
||||
return strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "[MCP returned an embedded resource.]"
|
||||
}
|
||||
|
||||
resource := content.Resource
|
||||
if resource.URI != "" {
|
||||
return fmt.Sprintf(
|
||||
"[MCP returned embedded resource %q (%s).]",
|
||||
resource.URI,
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
|
||||
}
|
||||
|
||||
func annotationsAllowUser(annotations *mcp.Annotations) bool {
|
||||
if annotations == nil || len(annotations.Audience) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, audience := range annotations.Audience {
|
||||
if strings.EqualFold(string(audience), "user") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// normalizedMIMEType returns mimeType with surrounding whitespace removed, or
// "application/octet-stream" when nothing remains.
//
// Blank input was already treated as missing; trimming non-blank input as
// well keeps padded values such as " image/png\n" from being passed on
// verbatim as the content type.
func normalizedMIMEType(mimeType string) string {
	if trimmed := strings.TrimSpace(mimeType); trimmed != "" {
		return trimmed
	}
	return "application/octet-stream"
}
|
||||
|
||||
// compactStrings returns a new slice holding, in order, every element of parts
// that is not blank (empty or whitespace only). Kept elements are not
// trimmed. The result is never nil, even for nil input.
func compactStrings(parts []string) []string {
	kept := make([]string, 0, len(parts))
	for i := range parts {
		if strings.TrimSpace(parts[i]) != "" {
			kept = append(kept, parts[i])
		}
	}
	return kept
}
|
||||
|
||||
@@ -3,10 +3,14 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
@@ -490,3 +494,143 @@ func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("fake-image-bytes"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
if result.ResponseHandled {
|
||||
t.Fatal("expected MCP image artifact not to mark response as handled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "stored as a local media artifact") {
|
||||
t.Fatalf("expected local media artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
|
||||
path, meta, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media ref to resolve: %v", err)
|
||||
}
|
||||
if meta.ContentType != "image/png" {
|
||||
t.Fatalf("expected image/png content type, got %q", meta.ContentType)
|
||||
}
|
||||
if filepath.Ext(path) != ".png" {
|
||||
t.Fatalf("expected png temp file, got %q", path)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != "fake-image-bytes" {
|
||||
t.Fatalf("expected stored media bytes to match input, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.EmbeddedResource{
|
||||
Resource: &mcp.ResourceContents{
|
||||
URI: "file:///tmp/report.png",
|
||||
MIMEType: "image/png",
|
||||
Blob: []byte("blob-bytes"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
|
||||
}
|
||||
path, _, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media ref to resolve: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != "blob-bytes" {
|
||||
t.Fatalf("expected stored blob bytes to match input, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("assistant-only"),
|
||||
MIMEType: "image/png",
|
||||
Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 0 {
|
||||
t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "non-user audience") {
|
||||
t.Fatalf("expected audience note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: strings.Repeat("QUJD", 400)},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if result.ForLLM != largeBase64OmittedMessage {
|
||||
t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// Placeholder notes written into model-facing tool output when bulky or
// binary-looking content is removed from it.
const (
	// largeBase64OmittedMessage replaces text that looks like one large base64 blob.
	largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]"
	// inlineMediaOmittedMessage marks that an inline data: URL was stripped from the text.
	inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]"
	// inlineMediaStoredMessage is a format string; its %s verb receives the MIME type
	// of inline media that was stripped and registered with the media store.
	inlineMediaStoredMessage = "[Tool returned inline media content (%s); omitted from model context and registered as a media attachment.]"
)

var (
	// inlineMarkdownDataURLRe matches a Markdown image whose target is a data: URL.
	// Capture group 1 is the data: URL itself.
	inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`)
	// inlineRawDataURLRe matches a bare base64 data: URL. The payload character
	// class includes CR and LF, so a payload wrapped over several lines matches
	// as one URL.
	// NOTE(review): for the same reason, a following line made only of base64
	// alphabet characters is absorbed into the match — confirm this is acceptable.
	inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`)
)
|
||||
|
||||
func normalizeToolResult(
|
||||
result *ToolResult,
|
||||
toolName string,
|
||||
store media.MediaStore,
|
||||
channel string,
|
||||
chatID string,
|
||||
) *ToolResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
notes := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
if store != nil && channel != "" && chatID != "" {
|
||||
var refs []string
|
||||
var extractedNotes []string
|
||||
|
||||
result.ForLLM, refs, extractedNotes = extractInlineMediaRefs(
|
||||
result.ForLLM,
|
||||
toolName,
|
||||
store,
|
||||
channel,
|
||||
chatID,
|
||||
seen,
|
||||
)
|
||||
result.Media = append(result.Media, refs...)
|
||||
notes = append(notes, extractedNotes...)
|
||||
|
||||
result.ForUser, refs, extractedNotes = extractInlineMediaRefs(
|
||||
result.ForUser,
|
||||
toolName,
|
||||
store,
|
||||
channel,
|
||||
chatID,
|
||||
seen,
|
||||
)
|
||||
result.Media = append(result.Media, refs...)
|
||||
notes = append(notes, extractedNotes...)
|
||||
}
|
||||
|
||||
result.ForLLM = sanitizeToolLLMContent(result.ForLLM)
|
||||
|
||||
if len(result.Media) > 0 && len(notes) > 0 {
|
||||
if strings.TrimSpace(result.ForLLM) == "" {
|
||||
result.ForLLM = strings.Join(notes, "\n")
|
||||
} else {
|
||||
result.ForLLM = strings.TrimSpace(result.ForLLM) + "\n" + strings.Join(notes, "\n")
|
||||
}
|
||||
}
|
||||
if len(result.Media) > 0 && strings.TrimSpace(result.ForLLM) == "" {
|
||||
result.ForLLM = "[Tool returned media content; omitted from model context and registered as a media attachment.]"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func sanitizeToolLLMContent(text string) string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return text
|
||||
}
|
||||
if inlineMarkdownDataURLRe.MatchString(trimmed) || inlineRawDataURLRe.MatchString(trimmed) {
|
||||
cleaned := inlineMarkdownDataURLRe.ReplaceAllString(trimmed, "")
|
||||
cleaned = inlineRawDataURLRe.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
if cleaned == "" {
|
||||
return inlineMediaOmittedMessage
|
||||
}
|
||||
return cleaned + "\n" + inlineMediaOmittedMessage
|
||||
}
|
||||
if looksLikeLargeBase64Payload(trimmed) {
|
||||
return largeBase64OmittedMessage
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// looksLikeLargeBase64Payload reports whether text, after trimming surrounding
// whitespace, resembles one large base64 blob.
//
// All three conditions must hold:
//   - the trimmed text is at least 1024 bytes long;
//   - at least 97% of its non-whitespace runes belong to the standard base64
//     alphabet (A-Z, a-z, 0-9, '+', '/', '=');
//   - it contains at most one whitespace rune per 128 bytes, which rules out
//     ordinary prose.
func looksLikeLargeBase64Payload(text string) bool {
	const (
		minPayloadBytes    = 1024
		minAlphabetRatio   = 0.97
		bytesPerWhitespace = 128
	)

	body := strings.TrimSpace(text)
	if len(body) < minPayloadBytes {
		return false
	}

	var whitespace, visible, inAlphabet int
	for _, ch := range body {
		switch {
		case unicode.IsSpace(ch):
			whitespace++
		case ch >= 'A' && ch <= 'Z',
			ch >= 'a' && ch <= 'z',
			ch >= '0' && ch <= '9',
			ch == '+', ch == '/', ch == '=':
			visible++
			inAlphabet++
		default:
			visible++
		}
	}

	if visible == 0 {
		return false
	}
	if whitespace > len(body)/bytesPerWhitespace {
		return false
	}
	return float64(inAlphabet)/float64(visible) >= minAlphabetRatio
}
|
||||
|
||||
func extractInlineMediaRefs(
|
||||
text string,
|
||||
toolName string,
|
||||
store media.MediaStore,
|
||||
channel string,
|
||||
chatID string,
|
||||
seen map[string]struct{},
|
||||
) (cleaned string, refs []string, notes []string) {
|
||||
cleaned = text
|
||||
|
||||
matches := inlineMarkdownDataURLRe.FindAllStringSubmatch(cleaned, -1)
|
||||
for _, match := range matches {
|
||||
if len(match) < 2 {
|
||||
continue
|
||||
}
|
||||
dataURL := match[1]
|
||||
ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
notes = append(notes, note)
|
||||
}
|
||||
cleaned = strings.ReplaceAll(cleaned, match[0], "")
|
||||
}
|
||||
|
||||
rawMatches := inlineRawDataURLRe.FindAllString(cleaned, -1)
|
||||
for _, dataURL := range rawMatches {
|
||||
ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
notes = append(notes, note)
|
||||
}
|
||||
cleaned = strings.ReplaceAll(cleaned, dataURL, "")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cleaned), refs, notes
|
||||
}
|
||||
|
||||
func storeInlineDataURL(
|
||||
toolName string,
|
||||
store media.MediaStore,
|
||||
channel string,
|
||||
chatID string,
|
||||
dataURL string,
|
||||
seen map[string]struct{},
|
||||
) (ref string, note string) {
|
||||
dataURL = strings.TrimSpace(dataURL)
|
||||
if _, ok := seen[dataURL]; ok {
|
||||
return "", ""
|
||||
}
|
||||
seen[dataURL] = struct{}{}
|
||||
|
||||
if !strings.HasPrefix(strings.ToLower(dataURL), "data:") {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
comma := strings.IndexByte(dataURL, ',')
|
||||
if comma <= 5 {
|
||||
return "", "[Tool returned inline media content that could not be parsed.]"
|
||||
}
|
||||
|
||||
metaPart := dataURL[:comma]
|
||||
payload := dataURL[comma+1:]
|
||||
if !strings.Contains(strings.ToLower(metaPart), ";base64") {
|
||||
return "", "[Tool returned inline media content that was not base64-encoded.]"
|
||||
}
|
||||
|
||||
mimeType := strings.TrimSpace(strings.TrimPrefix(metaPart, "data:"))
|
||||
if semi := strings.IndexByte(mimeType, ';'); semi >= 0 {
|
||||
mimeType = mimeType[:semi]
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
|
||||
payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload)
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType)
|
||||
}
|
||||
|
||||
dir := media.TempDir()
|
||||
if err = os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(decoded); err != nil {
|
||||
tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
filename := sanitizeIdentifierComponent(toolName) + ext
|
||||
scope := fmt.Sprintf(
|
||||
"tool:inline:%s:%s:%s:%d",
|
||||
sanitizeIdentifierComponent(toolName),
|
||||
channel,
|
||||
chatID,
|
||||
time.Now().UnixNano(),
|
||||
)
|
||||
|
||||
ref, err = store.Store(tmpPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: mimeType,
|
||||
Source: fmt.Sprintf("tool:inline:%s", sanitizeIdentifierComponent(toolName)),
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType)
|
||||
}
|
||||
|
||||
return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType)
|
||||
}
|
||||
|
||||
// extensionForMIMEType returns a file extension (including the leading dot)
// for mimeType. It never returns an empty string.
//
// Lookup order:
//  1. A fixed table of common media types, so the result for e.g. image/jpeg
//     is always ".jpg". mime.ExtensionsByType returns a sorted list whose
//     first entry depends on the host's MIME database (".jfif", ".jpe",
//     ".jpeg", ...), so it must not be consulted first for these types.
//  2. mime.ExtensionsByType for everything else.
//  3. The dotted suffix of the type string itself (heuristic kept from the
//     previous implementation).
//  4. ".bin" when nothing else applies, including for an empty type.
//
// Matching is case-insensitive and ignores surrounding whitespace.
func extensionForMIMEType(mimeType string) string {
	normalized := strings.ToLower(strings.TrimSpace(mimeType))
	if normalized == "" {
		return ".bin"
	}

	switch normalized {
	case "image/jpeg":
		return ".jpg"
	case "image/png":
		return ".png"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/x-wav":
		return ".wav"
	case "audio/mpeg":
		return ".mp3"
	case "audio/ogg":
		return ".ogg"
	case "video/mp4":
		return ".mp4"
	}

	if exts, err := mime.ExtensionsByType(normalized); err == nil && len(exts) > 0 {
		return exts[0]
	}
	if ext := filepath.Ext(normalized); ext != "" {
		return ext
	}
	return ".bin"
}
|
||||
+34
-5
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -19,9 +20,14 @@ type ToolEntry struct {
|
||||
}
|
||||
|
||||
type ToolRegistry struct {
|
||||
tools map[string]*ToolEntry
|
||||
mu sync.RWMutex
|
||||
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
|
||||
tools map[string]*ToolEntry
|
||||
mu sync.RWMutex
|
||||
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
// mediaStoreAware is implemented by tools that accept a media store. The
// registry uses a type assertion against this interface to decide which tools
// receive the store.
type mediaStoreAware interface {
	// SetMediaStore hands the store to the tool.
	SetMediaStore(store media.MediaStore)
}
|
||||
|
||||
func NewToolRegistry() *ToolRegistry {
|
||||
@@ -43,6 +49,9 @@ func (r *ToolRegistry) Register(tool Tool) {
|
||||
IsCore: true,
|
||||
TTL: 0, // Core tools do not use TTL
|
||||
}
|
||||
if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil {
|
||||
aware.SetMediaStore(r.mediaStore)
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
|
||||
}
|
||||
@@ -61,10 +70,27 @@ func (r *ToolRegistry) RegisterHidden(tool Tool) {
|
||||
IsCore: false,
|
||||
TTL: 0,
|
||||
}
|
||||
if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil {
|
||||
aware.SetMediaStore(r.mediaStore)
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name})
|
||||
}
|
||||
|
||||
// SetMediaStore injects a MediaStore into all registered tools that can
|
||||
// consume it, and remembers it for future registrations.
|
||||
func (r *ToolRegistry) SetMediaStore(store media.MediaStore) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.mediaStore = store
|
||||
for _, entry := range r.tools {
|
||||
if aware, ok := entry.Tool.(mediaStoreAware); ok {
|
||||
aware.SetMediaStore(store)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PromoteTools atomically sets the TTL for multiple non-core tools.
|
||||
// This prevents a concurrent TickTTL from decrementing between promotions.
|
||||
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
|
||||
@@ -238,6 +264,8 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
}
|
||||
}
|
||||
|
||||
result = normalizeToolResult(result, name, r.mediaStore, channel, chatID)
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log based on result type
|
||||
@@ -259,7 +287,7 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result.ForLLM),
|
||||
"result_length": len(result.ContentForLLM()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -354,7 +382,8 @@ func (r *ToolRegistry) Clone() *ToolRegistry {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
clone := &ToolRegistry{
|
||||
tools: make(map[string]*ToolEntry, len(r.tools)),
|
||||
tools: make(map[string]*ToolEntry, len(r.tools)),
|
||||
mediaStore: r.mediaStore,
|
||||
}
|
||||
for name, entry := range r.tools {
|
||||
clone.tools[name] = &ToolEntry{
|
||||
|
||||
@@ -3,10 +3,13 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -46,6 +49,15 @@ func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]
|
||||
return m.result
|
||||
}
|
||||
|
||||
// mockMediaStoreAwareTool is a mockRegistryTool that also implements
// SetMediaStore, recording the store it was given so tests can inspect it.
type mockMediaStoreAwareTool struct {
	mockRegistryTool
	// store is the last value passed to SetMediaStore.
	store media.MediaStore
}

// SetMediaStore records the store handed to the tool.
func (m *mockMediaStoreAwareTool) SetMediaStore(store media.MediaStore) {
	m.store = store
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func newMockTool(name, desc string) *mockRegistryTool {
|
||||
@@ -621,3 +633,102 @@ func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) {
|
||||
t.Errorf("expected 'success', got %q", result2.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
existing := &mockMediaStoreAwareTool{
|
||||
mockRegistryTool: *newMockTool("existing", "existing tool"),
|
||||
}
|
||||
r.Register(existing)
|
||||
|
||||
r.SetMediaStore(store)
|
||||
if existing.store != store {
|
||||
t.Fatal("expected existing tool to receive media store")
|
||||
}
|
||||
|
||||
later := &mockMediaStoreAwareTool{
|
||||
mockRegistryTool: *newMockTool("later", "later tool"),
|
||||
}
|
||||
r.Register(later)
|
||||
|
||||
if later.store != store {
|
||||
t.Fatal("expected newly registered tool to inherit media store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
payload := strings.Repeat("QUJD", 400)
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "base64_tool",
|
||||
desc: "returns huge base64",
|
||||
params: map[string]any{},
|
||||
result: SilentResult(payload),
|
||||
})
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "base64_tool", nil, "telegram", "chat-1", nil)
|
||||
|
||||
if result.ForLLM != largeBase64OmittedMessage {
|
||||
t.Fatalf("expected sanitized payload, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
store := media.NewFileMediaStore()
|
||||
r.SetMediaStore(store)
|
||||
|
||||
payload := ""
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "inline_media_tool",
|
||||
desc: "returns inline data url",
|
||||
params: map[string]any{},
|
||||
result: SilentResult(payload),
|
||||
})
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "inline_media_tool", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "data:image/png;base64") {
|
||||
t.Fatalf("expected inline data URL to be stripped from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "registered as a media attachment") {
|
||||
t.Fatalf("expected delivery note in ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
|
||||
path, err := store.Resolve(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media ref to resolve: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("expected stored media file to exist: %v", err)
|
||||
}
|
||||
if filepath.Ext(path) != ".png" {
|
||||
t.Fatalf("expected stored inline media to use png extension, got %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
|
||||
payload := "before  after"
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "inline_media_no_store",
|
||||
desc: "returns inline data url without store",
|
||||
params: map[string]any{},
|
||||
result: SilentResult(payload),
|
||||
})
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "inline_media_no_store", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if strings.Contains(result.ForLLM, "data:image/png;base64") {
|
||||
t.Fatalf("expected inline data URL to be removed from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, inlineMediaOmittedMessage) {
|
||||
t.Fatalf("expected inline media omission note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,16 @@ package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// Notes appended to the model-facing content of a tool result.
const (
	// handledToolLLMNote is added when the result is marked ResponseHandled.
	handledToolLLMNote = "The requested output has already been delivered to the user in the current chat. Do not call send_file or any other delivery tool again. If you reply, provide only a brief confirmation."
	// artifactPathsLLMNote follows the list of artifact paths when ArtifactTags is non-empty.
	artifactPathsLLMNote = "Use `send_file` with one of these paths to send it to the user, or use file/exec tools to save it inside the workspace if requested."
)
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
// async operations, user-facing messages, and error handling.
|
||||
@@ -43,6 +49,48 @@ type ToolResult struct {
|
||||
// Only populated by SubTurn executions; used by evaluator_optimizer
|
||||
// to carry stateful worker context across evaluation iterations.
|
||||
Messages []providers.Message `json:"-"`
|
||||
|
||||
// ArtifactTags exposes local artifact paths back to the LLM in a structured
|
||||
// form, e.g. "[file:/tmp/example.png]". This is used when a tool produced a
|
||||
// reusable local artifact but did not deliver it to the user yet.
|
||||
ArtifactTags []string `json:"artifact_tags,omitempty"`
|
||||
|
||||
// ResponseHandled indicates that this tool execution already satisfied the
|
||||
// user's request at the channel/output level, so the agent loop can stop
|
||||
// without a follow-up assistant response.
|
||||
ResponseHandled bool `json:"response_handled,omitempty"`
|
||||
}
|
||||
|
||||
// ContentForLLM returns the normalized textual content to append to the
|
||||
// conversation after a tool call. Errors fall back to Err when ForLLM is empty.
|
||||
func (tr *ToolResult) ContentForLLM() string {
|
||||
if tr == nil {
|
||||
return ""
|
||||
}
|
||||
content := tr.ForLLM
|
||||
if content == "" && tr.Err != nil {
|
||||
content = tr.Err.Error()
|
||||
}
|
||||
if tr.ResponseHandled {
|
||||
if content == "" {
|
||||
return handledToolLLMNote
|
||||
}
|
||||
if !strings.Contains(content, handledToolLLMNote) {
|
||||
content += "\n" + handledToolLLMNote
|
||||
}
|
||||
}
|
||||
if len(tr.ArtifactTags) > 0 {
|
||||
artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote
|
||||
if content == "" {
|
||||
content = artifactNote
|
||||
} else if !strings.Contains(content, artifactNote) {
|
||||
content += "\n" + artifactNote
|
||||
}
|
||||
}
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
@@ -167,3 +215,9 @@ func (tr *ToolResult) WithError(err error) *ToolResult {
|
||||
tr.Err = err
|
||||
return tr
|
||||
}
|
||||
|
||||
// WithResponseHandled sets ResponseHandled to true on the receiver and returns
// the same pointer so calls can be chained.
func (tr *ToolResult) WithResponseHandled() *ToolResult {
	tr.ResponseHandled = true
	return tr
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -227,3 +228,41 @@ func TestToolResultJSONStructure(t *testing.T) {
|
||||
t.Errorf("Expected silent false, got %v", parsed["silent"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultContentForLLM_AppendsHandledDeliveryNote(t *testing.T) {
|
||||
result := MediaResult("Screenshot attached.", []string{"media://example"}).WithResponseHandled()
|
||||
|
||||
content := result.ContentForLLM()
|
||||
if !strings.Contains(content, "Screenshot attached.") {
|
||||
t.Fatalf("expected original content in ContentForLLM, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, handledToolLLMNote) {
|
||||
t.Fatalf("expected handled delivery note in ContentForLLM, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty(t *testing.T) {
|
||||
result := (&ToolResult{}).WithResponseHandled()
|
||||
|
||||
if got := result.ContentForLLM(); got != handledToolLLMNote {
|
||||
t.Fatalf("ContentForLLM() = %q, want %q", got, handledToolLLMNote)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultContentForLLM_AppendsArtifactPaths(t *testing.T) {
|
||||
result := &ToolResult{
|
||||
ForLLM: "Artifact created.",
|
||||
ArtifactTags: []string{"[file:/tmp/example.png]"},
|
||||
}
|
||||
|
||||
content := result.ContentForLLM()
|
||||
if !strings.Contains(content, "Artifact created.") {
|
||||
t.Fatalf("expected original content in ContentForLLM, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "Local artifact paths: [file:/tmp/example.png]") {
|
||||
t.Fatalf("expected artifact path note in ContentForLLM, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, artifactPathsLLMNote) {
|
||||
t.Fatalf("expected artifact guidance note in ContentForLLM, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
|
||||
}
|
||||
|
||||
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref})
|
||||
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
// detectMediaType determines the MIME type of a file.
|
||||
|
||||
@@ -104,6 +104,9 @@ func TestSendFileTool_Success(t *testing.T) {
|
||||
if result.Media[0][:8] != "media://" {
|
||||
t.Errorf("expected media:// ref, got %q", result.Media[0])
|
||||
}
|
||||
if !result.ResponseHandled {
|
||||
t.Fatal("expected send_file success to mark response handled")
|
||||
}
|
||||
|
||||
_, meta, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
|
||||
@@ -159,10 +159,7 @@ func RunToolLoop(
|
||||
|
||||
// Append results in original order
|
||||
for _, r := range results {
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
}
|
||||
contentForLLM := r.result.ContentForLLM()
|
||||
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "tool",
|
||||
|
||||
Reference in New Issue
Block a user